diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index b9c625c7332a229f0e2d9e51c0df152c7f40de07..e8b0b2f352bedf71aa3b39cd03046bcc6c847d60 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -108,6 +108,7 @@ build-examples: matrix: - EXAMPLE: - matmul_clamp_f16_f16_f16p + - matmul_clamp_f32_bf16p_bf16p - matmul_clamp_f32_qai8dxp_qsi4cxp - matmul_clamp_f32_qsi8d32p_qsi4c32p - matmul_clamp_f32_qai8dxp_qsi4c32p @@ -130,6 +131,7 @@ test-examples: matrix: - EXAMPLE: - matmul_clamp_f16_f16_f16p + - matmul_clamp_f32_bf16p_bf16p - matmul_clamp_f32_qai8dxp_qsi4cxp - matmul_clamp_f32_qsi8d32p_qsi4c32p - matmul_clamp_f32_qai8dxp_qsi4c32p diff --git a/CMakeLists.txt b/CMakeLists.txt index 642fce8d016c1163d42bc1b427e98102f37d7bae..4d846759531622d4d0962a8b8690fed45239e9d4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -88,6 +88,12 @@ set(KLEIDIAI_FILES_NEON_FP16 kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c ) +set(KLEIDIAI_FILES_NEON_BF16 + kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.c + kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c + kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c +) + set(KLEIDIAI_FILES_NEON kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.c @@ -137,6 +143,7 @@ target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SCALAR}) if((CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") AND NOT MSVC) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_FP16}) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_BF16}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_DOTPROD}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_I8MM}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME}) @@ -145,6 +152,7 @@ if((CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS -march=armv8-a${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON} PROPERTIES COMPILE_OPTIONS -march=armv8-a${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_FP16} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+fp16${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_NEON_BF16} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+bf16${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+dotprod${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+i8mm${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_SME} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+sve+sve2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) @@ -170,6 +178,7 @@ if(KLEIDIAI_BUILD_TESTS) test/common/printer.cpp test/common/int4.cpp test/common/compare.cpp + test/common/matmul_test_common.cpp test/common/matrix_portion.cpp test/common/rect.cpp test/common/round.cpp @@ -205,6 +214,7 @@ if(KLEIDIAI_BUILD_TESTS) test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp + test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp ) target_link_libraries(kleidiai_test diff --git a/README.md b/README.md index a139d11b1bd7cfc3d05746b4cd42d32ede52482c..2d963d480dc16346ea1ca854c5523f635636288d 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,7 @@ Some of the data types currently supported with the KleidiAI library are the fol |---------------------------------------------------------------------------------------------------------------------| ----------- | ----------- | | Floating-point 32-bit | f32 | | | Floating-point 16-bit | f16 | | +| Brain Floating-point 16-bit | bf16 | | | Quantized (q) Symmetric (s) Signed (i) 4-bit (4) Per-Channel (cx) quantization parameters | qsi4cx | An fp32 multiplier shared among all values of the same channel. `x` denotes the entirety of the channel | | Quantized (q) Asymmetric (a) Signed (i) 8-bit (8) Per-Dimension (dx) (for example, Per-Row) quantization parameters | qai8dx | An fp32 multiplier and a int32 zero offset shared among all values of the same dimension. | @@ -177,6 +178,20 @@ Some of the data types currently supported with the KleidiAI library are the fol
+ + Matrix-multiplication with LHS packed and RHS packed matrices + matmul_clamp_f32_bf16p_bf16p + + LHS: bf16p
+ RHS: bf16p
+ DST: f32
+ + + + + The packing function for the RHS and Lhs matrices is listed in the header file of the GEMM micro kernel.
+ +

How to build

diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..4b13a183a140c6ec1179ece1befa9c029010f318 --- /dev/null +++ b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt @@ -0,0 +1,36 @@ +# +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +cmake_minimum_required(VERSION 3.16) + +project(KleidiAI) + +set(CMAKE_CXX_STANDARD 17) +set(KLEIDIAI_PATH ../../) +set(MATMUL_PACK_PATH ${KLEIDIAI_PATH}/kai/ukernels/matmul/pack/) +set(MATMUL_PATH ${KLEIDIAI_PATH}/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/) + +# KleidiAI include directories +include_directories( + ${KLEIDIAI_PATH} + ${MATMUL_PACK_PATH} + ${MATMUL_PATH}) + +# Files requires to build the executable +add_executable(matmul_clamp_f32_bf16p_bf16p + matmul_clamp_f32_bf16p_bf16p.cpp + ${MATMUL_PATH}/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c + ${MATMUL_PACK_PATH}/kai_lhs_quant_pack_bf16p_f32_neon.c + ${MATMUL_PACK_PATH}/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c +) + +target_compile_options(matmul_clamp_f32_bf16p_bf16p + PRIVATE -march=armv8.2-a+bf16 +) + +target_compile_definitions(matmul_clamp_f32_bf16p_bf16p + PRIVATE $<$:KAI_DEBUG> +) diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f11900c3a78dce544f1e4b457ae7e2d4c1fb1bcb --- /dev/null +++ b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp @@ -0,0 +1,321 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Example usage for matrix multiplication of two half-precision brain floating-point (BF16) matrices +// and the accumulation of the result into an FP32 destination matrix. +// +// The activations and the weights, stored in the LHS and RHS matrices respectively, are both non-transposed matrices. +// The matrix multiplication computation is performed using BF16 matrix multiply (BFMMLA) +// vector instructions present in the FEAT_BF16 ArmĀ® architecture feature. +// +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_SCALAR_ARITHMETIC) || \ + !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_BF16. +#else +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +// Include micro-kernel variants +#include "kai/kai_common.h" +#include "kai_lhs_quant_pack_bf16p_f32_neon.h" +#include "kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h" +#include "kai_matmul_clamp_f32_bf16p_bf16p_interface.h" +#include "kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h" + +inline static float bf16_to_float(const uint16_t* v) { + const uint16_t uint_rep = *v; + return kai_cast_f32_bf16(uint_rep); +} + +namespace { +/// Micro-kernel interface +constexpr kai_matmul_clamp_f32_bf16p_bf16p_ukernel ukernel{ + kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla}; + +/// @brief Truncate the 32-bit floating point number's least significant 16 mantissa bits +/// @param x floating-point number +/// @return truncated floating-point number +inline static float truncate(float x) { + uint32_t uval = (*reinterpret_cast(&x) & 0xffff0000); + return *reinterpret_cast(&uval); +} + +/// Reference implementation of matrix multiplication +static void run_matmul_ref( + size_t m, size_t n, size_t k, const float* lhs, const float* rhs, const float* bias, float* dst, float scalar_min, + float scalar_max) { + for (size_t row_idx = 0; row_idx < m; ++row_idx) { + for (size_t col_idx = 0; col_idx < n; ++col_idx) { + float acc = bias[col_idx]; + + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + float lhs_val = truncate(lhs[row_idx * k + k_idx]); + float rhs_val = truncate(rhs[col_idx + n * k_idx]); + + acc += lhs_val * rhs_val; + } + + dst[row_idx * n + col_idx] = std::clamp(acc, scalar_min, scalar_max); + } + } +} + +/// Fills the matrix with incremental values +void fill_matrix(size_t num_rows, size_t num_cols, float* dst, const float weight) { + for (size_t i = 0; i < num_rows * num_cols; i++) { + dst[i] = float((i + 1) * weight); + } +} + +/// Print the matrix +void print_matrix(size_t num_rows, size_t num_cols, const char* name, const float* src) { + std::cout << name << " = [\n"; + for (size_t y = 0; y < num_rows; ++y) { + std::cout << " ["; + for (size_t x = 0; x < num_cols; ++x) { + std::cout << std::setprecision(2) << std::fixed << src[y * num_cols + x] << ", "; + } + std::cout << ("],\n"); + } + std::cout << ("]\n\n"); +} + +void print_matrix(size_t num_rows, size_t num_cols, const char* name, const uint16_t* src) { + std::cout << name << " = [\n"; + for (size_t y = 0; y < num_rows; ++y) { + std::cout << " ["; + for (size_t x = 0; x < num_cols; ++x) { + std::cout << std::setprecision(2) << std::fixed << bf16_to_float(&src[y * num_cols + x]) << ", "; + } + std::cout << ("],\n"); + } + std::cout << ("]\n\n"); +} + +void print_mixed_prec_matrix( + size_t num_rows, size_t num_cols, const char* name, const uint8_t* src, int nr, int stride) { + std::cout << name << " = [\n"; + for (size_t y = 0; y < num_rows; ++y) { + std::cout << " ["; + const uint8_t* src_row = src + stride * y; + for (size_t x = 0; x < num_cols; ++x) { + if (x >= nr) { + // print bfloat + const uint16_t* src_elm = + reinterpret_cast(src_row + nr * sizeof(float) + (x - nr) * sizeof(uint16_t)); + std::cout << std::setprecision(2) << std::fixed << bf16_to_float(src_elm) << ", "; + } else { + // print float + const float* src_elm = reinterpret_cast(src_row + x * sizeof(float)); + std::cout << std::setprecision(2) << std::fixed << *src_elm << ", "; + } + } + std::cout << ("],\n"); + } + std::cout << ("]\n\n"); +} + +void print_bf_matrix(size_t num_rows, size_t num_cols, const char* name, const float* src) { + std::cout << name << " = [\n"; + for (size_t y = 0; y < num_rows; ++y) { + std::cout << " ["; + for (size_t x = 0; x < num_cols; ++x) { + std::cout << std::setprecision(2) << std::fixed << truncate(src[y * num_cols + x]) << ", "; + } + std::cout << ("],\n"); + } + std::cout << ("]\n\n"); +} + +/// Verify the micro-kernel output matches the reference implementation +bool is_output_correct( + size_t num_rows, size_t num_cols, const float rel_tolerance, const float* ref, const float* act) { + bool is_valid = true; + + for (size_t i = 0; i < num_rows * num_cols; ++i) { + if (std::fabs(ref[i] - act[i]) / (act[i] + 1e-10) > rel_tolerance) { + const size_t x = i % num_cols; + const size_t y = i / num_cols; + + std::cout << std::setprecision(5) << std::fixed << "ERROR![" << y << "][" << x << "]: ref=" << ref[i] + << " vs. act=" << act[i] << "\n"; + + is_valid = false; + } + } + return is_valid; +} +} // namespace + +int main() { + // Parameters of the matrix multiplication. Change these values to see how the micro-kernels operate on different + // sized matrices + const size_t M = 25; // Rows of LHS and DST matrices + const size_t N = 28; // Columns of RHS and DST matrices, and length of the Bias vector. + const size_t K = 117; // Columns of LHS, rows of RHS matrices + + const size_t lhs_size = M * K; + const size_t rhs_size = N * K; + const size_t bias_size = N; + const size_t dst_size = M * N; + + // Allocate the memory + float* lhs = new float[lhs_size]; + float* rhs = new float[rhs_size]; + float* bias = new float[bias_size]; + + fill_matrix(M, K, lhs, 0.1); + fill_matrix(K, N, rhs, 0.1); + fill_matrix(1, N, bias, 1); + +#ifdef KAI_DEBUG + // std::cout << "Floats: " << std::endl; + print_matrix(M, K, "lhs", lhs); + print_matrix(K, N, "rhs", rhs); + print_matrix(1, N, "bias", bias); + + // Print bf16 converted values + print_bf_matrix(M, K, "lhs_bf", lhs); + print_bf_matrix(K, N, "rhs_bf", rhs); +#endif // KAI_DEBUG + + //----------- REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + float* dst_ref = new float[dst_size]; + + run_matmul_ref( + M, N, K, // Dimensions + lhs, // LHS buffer + rhs, // RHS buffer + bias, // Bias buffer + dst_ref, // DST + -FLT_MAX, FLT_MAX // Min and max for the clamp operation + ); + //----------- END REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + + //----------- MICRO-KERNELS TESTS + //------------------------------------ + //------------------------------------ + const size_t mr = ukernel.get_mr(); + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + + // In a single row, we pack nr bias values followed by K rows of nr RHS values + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(N, K, nr, kr); + uint8_t* rhs_packed = new uint8_t[rhs_packed_size]; + + const size_t lhs_stride = K * sizeof(float); + const size_t rhs_stride = N * sizeof(float); + const size_t dst_stride_row = N * sizeof(float); + const size_t dst_stride_col = sizeof(float); + + const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon(M, K, mr, kr, sr); + uint16_t* lhs_packed = new uint16_t[lhs_packed_size]; + + // Packing only needs to be performed once if the contents of the bias and RHS matrices are expected to be constant. + kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon( + 1, N, K, nr, kr, sr, // Packing arguments + rhs_stride, // RHS stride + rhs, // RHS + bias, // Bias + NULL, // Scale + rhs_packed, // RHS packed + 0, NULL); + + // The RHS and Bias buffers can be freed after packing, however we reuse them for the reference test below + +#ifdef KAI_DEBUG + const size_t rhs_packed_cols = nr + kai_roundup(K, kr) * nr; + + // Each col has nr floats and then K*nr bfloats + int rhs_packed_stride = nr * sizeof(float) + kai_roundup(K, kr) * nr * sizeof(uint16_t); + const size_t rhs_packed_rows = rhs_packed_size / rhs_packed_stride; + + print_mixed_prec_matrix(rhs_packed_rows, rhs_packed_cols, "rhs_packed", rhs_packed, nr, rhs_packed_stride); +#endif // KAI_DEBUG + + float* dst = new float[dst_size]; + + const auto timer_matmul_start = std::chrono::high_resolution_clock::now(); + + kai_run_lhs_quant_pack_bf16p_f32_neon(M, K, mr, kr, sr, 0 /* m_idx_start */, lhs, lhs_stride, lhs_packed); + + ukernel.run_matmul( + M, N, K, // Dimensions + lhs_packed, // LHS packed + rhs_packed, // RHS packed + dst, // DST + dst_stride_row, // DST stride (row) + dst_stride_col, // DST stride (col) + -FLT_MAX, FLT_MAX // Min and max for the clamp operation + ); + + const auto timer_matmul_end = std::chrono::high_resolution_clock::now(); + const auto time_matmul = + std::chrono::duration_cast(timer_matmul_end - timer_matmul_start); + + int ret = 0; + +#ifdef KAI_DEBUG + int num_lhs_rows = (M + mr - 1) / mr; + int num_lhs_cols = mr * kai_roundup(K, kr); + + print_matrix(num_lhs_rows, num_lhs_cols, "lhs_packed", lhs_packed); + print_matrix(M, N, "dst", dst); + print_matrix(M, N, "ref", dst_ref); +#endif // KAI_DEBUG + + constexpr float rel_tolerance = 0.02; // This value was chosen by experimentation + const bool is_valid = is_output_correct(M, N, rel_tolerance, dst_ref, dst); + + std::cout << "TEST[matmul_clamp_f32_bf16p_bf16p]\n"; + std::cout << "- ukernel: matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla\n"; + if (is_valid) { + std::cout << "- Status: PASSED\n"; + std::cout << "- Performance: " << time_matmul.count() << "ns\n"; + } else { + std::cout << "- Status: FAILED\n"; + ret = 1; + } + + //----------- END MICRO-KERNELS TESTS + //------------------------------------ + //------------------------------------ + + delete[] lhs; + delete[] rhs; + delete[] bias; + delete[] lhs_packed; + delete[] rhs_packed; + delete[] dst; + delete[] dst_ref; + + return ret; +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 7bcbc1a8cbffd30e9a44c39bc4bb7f2fa1b4fdcf..66a3c3868a97e70828efb60ef3f84d724c0e8ef2 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -7,6 +7,7 @@ load( "//:kai_defs.bzl", "kai_c_library", + "kai_cpu_bf16", "kai_cpu_dotprod", "kai_cpu_fp16", "kai_cpu_i8mm", @@ -32,6 +33,22 @@ kai_c_library( ], ) +kai_c_library( + name = "clamp_f32_bf16p_bf16p_interface", + hdrs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h"], + cpu_uarch = kai_cpu_bf16(), +) + +kai_c_library( + name = "clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla", + srcs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c"], + hdrs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h"], + cpu_uarch = kai_cpu_bf16(), + deps = [ + ":clamp_f32_bf16p_bf16p_interface", + ], +) + kai_c_library( name = "clamp_f32_f32_f32p", srcs = ["matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.c"], @@ -159,6 +176,13 @@ kai_c_library( cpu_uarch = kai_cpu_sme(), ) +kai_c_library( + name = "lhs_quant_pack_bf16p_f32_neon", + srcs = ["pack/kai_lhs_quant_pack_bf16p_f32_neon.c"], + hdrs = ["pack/kai_lhs_quant_pack_bf16p_f32_neon.h"], + cpu_uarch = kai_cpu_bf16(), +) + kai_c_library( name = "rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon", srcs = ["pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c"], @@ -166,6 +190,13 @@ kai_c_library( cpu_uarch = kai_cpu_fp16(), ) +kai_c_library( + name = "rhs_quant_pack_kxn_bf16pbiasf32_f32_neon", + srcs = ["pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c"], + hdrs = ["pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h"], + cpu_uarch = kai_cpu_bf16(), +) + kai_c_library( name = "rhs_pack_kxn_f32pbiasf32_f32_f32_neon", srcs = ["pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c"], @@ -306,6 +337,7 @@ kai_c_library( name = "matmul", deps = [ ":clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla", + ":clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla", ":clamp_f32_f32_f32p", ":clamp_f32_f32_f32pb_1x16vl_sme2_mla", ":clamp_f32_f32p_f32p", @@ -327,6 +359,7 @@ kai_c_library( ":clamp_f32_qsi8d32p_qsi4c32p_i8mm", ":clamp_f32_qsi8d32p_qsi4c32p_interface", ":lhs_pack_f32p2vlx1_f32_sme", + ":lhs_quant_pack_bf16p_f32_neon", ":lhs_quant_pack_qai8dxp_f32", ":lhs_quant_pack_qsi8d32p_f32", ":rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon", @@ -338,5 +371,6 @@ kai_c_library( ":rhs_pack_nxk_qsi4c32p_qsu4c32s1s0", ":rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0", ":rhs_pack_nxk_qsi4cxp_qs4cxs1s0", + ":rhs_quant_pack_kxn_bf16pbiasf32_f32_neon", ], ) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c new file mode 100644 index 0000000000000000000000000000000000000000..929e3753a90c40bb64072e34f2671d97b211bc53 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c @@ -0,0 +1,578 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_BF16. +#else // Architectural features check. + +#include "kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h" + +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_mr = 8; +static const size_t kai_nr = 12; +static const size_t kai_kr = 4; +static const size_t kai_sr = 1; + +size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { + return kai_mr; +} + +size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { + return kai_nr; +} + +size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m_idx, size_t k) { + KAI_ASSUME(m_idx % kai_mr == 0); + + return m_idx * kai_roundup(k, kai_kr) * sizeof(uint16_t); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_nr == 0); + + return n_idx * (sizeof(float) + kai_roundup(k, kai_kr) * sizeof(uint16_t)); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla( + size_t m_idx, size_t n_idx, size_t stride) { + KAI_ASSUME(m_idx % kai_mr == 0); + KAI_ASSUME(n_idx % kai_nr == 0); + + return m_idx * stride + n_idx * sizeof(float); +} + +size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla( + size_t m, size_t n, size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, size_t dst_stride_row, size_t dst_stride_col, // + float clamp_min, float clamp_max) { + KAI_ASSERT(dst_stride_col == sizeof(float)); + + const void* Apanel = lhs_packed; + void* Cpanel = dst; + size_t ldc = dst_stride_row / sizeof(float); + + size_t M = m; + + typedef struct { + float maxval; + float minval; + size_t N; + size_t K; + const void* Bpanel; + void* output_ptr; + } KernelArgs; + + KernelArgs ka; + + ka.N = n; + ka.K = kai_roundup(k, kai_kr) / kai_kr - 1; + + ka.Bpanel = rhs_packed; + + // Direct output. + ka.output_ptr = dst; + + // Clamping output. + ka.maxval = clamp_max; + ka.minval = clamp_min; + + __asm__ __volatile__( + "1:" // Height loop + "add x11, %x[Cpanel], %x[ldc], LSL #2\n" + "add x10, %x[Cpanel], %x[ldc], LSL #1\n" + "add x9, x11, %x[ldc], LSL #1\n" + "cmp %x[M], #0x8\n" + "add x28, %x[Cpanel], %x[ldc], LSL #3\n" + "add x27, %x[Cpanel], %x[ldc]\n" + "add x26, x10, %x[ldc]\n" + "add x25, x11, %x[ldc]\n" + "add x24, x9, %x[ldc]\n" + "bge 2f\n" + "cmp %x[M], #0x2\n" + "mov x24, %x[Cpanel]\n" + "csel x27, x27, %x[Cpanel], GE\n" + "csel x10, x10, %x[Cpanel], GT\n" + "cmp %x[M], #0x4\n" + "csel x26, x26, %x[Cpanel], GE\n" + "csel x11, x11, %x[Cpanel], GT\n" + "cmp %x[M], #0x6\n" + "csel x25, x25, %x[Cpanel], GE\n" + "csel x9, x9, %x[Cpanel], GT\n" + "2:" // all rows valid + "ldr x23, [%x[args_ptr], %[offsetof_N]]\n" + "ldr x22, [%x[args_ptr], %[offsetof_Bpanel]]\n" + "mov x21, %x[Apanel]\n" + "3:" // Width loop + "ldr q4, [x22, #0x0]\n" + "ldr q5, [x22, #0x10]\n" + "mov %x[Apanel], x21\n" + "ldr q6, [x22, #0x20]\n" + "ldr x20, [%x[args_ptr], %[offsetof_K]]\n" + "add x22, x22, #0x30\n" + "ldr q7, [x22, #0x0]\n" + "ldr q0, [%x[Apanel], #0x0]\n" + "ldr q1, [%x[Apanel], #0x10]\n" + "zip1 v8.2d, v4.2d, v4.2d\n" + "ldr q2, [%x[Apanel], #0x20]\n" + "zip2 v11.2d, v4.2d, v4.2d\n" + "ldr q4, [x22, #0x10]\n" + "zip1 v9.2d, v5.2d, v5.2d\n" + "zip2 v12.2d, v5.2d, v5.2d\n" + "cmp x20, #0x2\n" + "zip1 v10.2d, v6.2d, v6.2d\n" + "zip2 v13.2d, v6.2d, v6.2d\n" + "prfm pldl1keep, [%x[Apanel], #0x0]\n" + "mov v14.16b, v8.16b\n" + "mov v17.16b, v11.16b\n" + "prfm pldl1keep, [x22, #0x0]\n" + "mov v15.16b, v9.16b\n" + "mov v18.16b, v12.16b\n" + "prfm pldl1keep, [x22, #0x40]\n" + "mov v16.16b, v10.16b\n" + "mov v19.16b, v13.16b\n" + "prfm pldl1keep, [%x[Apanel], #0x40]\n" + "mov v20.16b, v8.16b\n" + "mov v21.16b, v9.16b\n" + "prfm pldl1keep, [x22, #0x80]\n" + "mov v22.16b, v10.16b\n" + "mov v23.16b, v11.16b\n" + "prfm pldl1keep, [%x[Apanel], #0x80]\n" + "mov v24.16b, v12.16b\n" + "mov v25.16b, v13.16b\n" + "prfm pldl1keep, [x22, #0xc0]\n" + "mov v26.16b, v8.16b\n" + "mov v27.16b, v9.16b\n" + "prfm pldl1keep, [x22, #0x100]\n" + "mov v28.16b, v10.16b\n" + "mov v29.16b, v11.16b\n" + "prfm pldl1keep, [%x[Apanel], #0xc0]\n" + "mov v30.16b, v12.16b\n" + "mov v31.16b, v13.16b\n" + "prfm pldl1keep, [x22, #0x140]\n" + "add x22, x22, #0x20\n" + "add %x[Apanel], %x[Apanel], #0x30\n" + "blt 5f\n" + "4:" // main loop head + "ldr q3, [%x[Apanel], #0x0]\n" + "ldr q5, [x22, #0x0]\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "ldr q6, [x22, #0x10]\n" + ".inst 0x6e44ec0b // bfmmla v11.4s, v0.8h, v4.8h\n" + ".inst 0x6e47ec2e // bfmmla v14.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec31 // bfmmla v17.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + "sub x20, x20, #0x2\n" + ".inst 0x6e44ec57 // bfmmla v23.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7a // bfmmla v26.4s, v3.8h, v7.8h\n" + "ldr q7, [x22, #0x20]\n" + ".inst 0x6e44ec7d // bfmmla v29.4s, v3.8h, v4.8h\n" + "ldr q4, [x22, #0x30]\n" + ".inst 0x6e45ec09 // bfmmla v9.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec2f // bfmmla v15.4s, v1.8h, v5.8h\n" + "cmp x20, #0x2\n" + ".inst 0x6e46ec32 // bfmmla v18.4s, v1.8h, v6.8h\n" + ".inst 0x6e45ec55 // bfmmla v21.4s, v2.8h, v5.8h\n" + "prfm pldl1keep, [%x[Apanel], #0x100]\n" + ".inst 0x6e46ec58 // bfmmla v24.4s, v2.8h, v6.8h\n" + ".inst 0x6e45ec7b // bfmmla v27.4s, v3.8h, v5.8h\n" + "ldr q5, [x22, #0x40]\n" + ".inst 0x6e46ec7e // bfmmla v30.4s, v3.8h, v6.8h\n" + "ldr q6, [x22, #0x50]\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + ".inst 0x6e44ec0d // bfmmla v13.4s, v0.8h, v4.8h\n" + "ldr q0, [%x[Apanel], #0x10]\n" + ".inst 0x6e47ec30 // bfmmla v16.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec33 // bfmmla v19.4s, v1.8h, v4.8h\n" + "ldr q1, [%x[Apanel], #0x20]\n" + ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n" + ".inst 0x6e44ec59 // bfmmla v25.4s, v2.8h, v4.8h\n" + "ldr q2, [%x[Apanel], #0x30]\n" + ".inst 0x6e47ec7c // bfmmla v28.4s, v3.8h, v7.8h\n" + "ldr q7, [x22, #0x60]\n" + ".inst 0x6e44ec7f // bfmmla v31.4s, v3.8h, v4.8h\n" + "ldr q3, [%x[Apanel], #0x40]\n" + "ldr q4, [x22, #0x70]\n" + ".inst 0x6e45ec08 // bfmmla v8.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec2e // bfmmla v14.4s, v1.8h, v5.8h\n" + ".inst 0x6e46ec31 // bfmmla v17.4s, v1.8h, v6.8h\n" + "prfm pldl1keep, [x22, #0x180]\n" + ".inst 0x6e45ec54 // bfmmla v20.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "prfm pldl1keep, [x22, #0x1c0]\n" + ".inst 0x6e45ec7a // bfmmla v26.4s, v3.8h, v5.8h\n" + "ldr q5, [x22, #0x80]\n" + ".inst 0x6e46ec7d // bfmmla v29.4s, v3.8h, v6.8h\n" + "ldr q6, [x22, #0x90]\n" + "prfm pldl1keep, [%x[Apanel], #0x140]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + "prfm pldl1keep, [x22, #0x200]\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e47ec2f // bfmmla v15.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec32 // bfmmla v18.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7b // bfmmla v27.4s, v3.8h, v7.8h\n" + "ldr q7, [x22, #0xa0]\n" + ".inst 0x6e44ec7e // bfmmla v30.4s, v3.8h, v4.8h\n" + "ldr q4, [x22, #0xb0]\n" + ".inst 0x6e45ec0a // bfmmla v10.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + "ldr q0, [%x[Apanel], #0x50]\n" + ".inst 0x6e45ec30 // bfmmla v16.4s, v1.8h, v5.8h\n" + ".inst 0x6e46ec33 // bfmmla v19.4s, v1.8h, v6.8h\n" + "ldr q1, [%x[Apanel], #0x60]\n" + ".inst 0x6e45ec56 // bfmmla v22.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" + "ldr q2, [%x[Apanel], #0x70]\n" + ".inst 0x6e45ec7c // bfmmla v28.4s, v3.8h, v5.8h\n" + ".inst 0x6e46ec7f // bfmmla v31.4s, v3.8h, v6.8h\n" + "add %x[Apanel], %x[Apanel], #0x80\n" + "add x22, x22, #0xc0\n" + "bge 4b\n" + "5:" // main loop skip + "ldr q3, [%x[Apanel], #0x0]\n" + "ldr q5, [x22, #0x0]\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "ldr q6, [x22, #0x10]\n" + ".inst 0x6e44ec0b // bfmmla v11.4s, v0.8h, v4.8h\n" + ".inst 0x6e47ec2e // bfmmla v14.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec31 // bfmmla v17.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + "add %x[Apanel], %x[Apanel], #0x10\n" + ".inst 0x6e44ec57 // bfmmla v23.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7a // bfmmla v26.4s, v3.8h, v7.8h\n" + "ldr q7, [x22, #0x20]\n" + ".inst 0x6e44ec7d // bfmmla v29.4s, v3.8h, v4.8h\n" + "ldr q4, [x22, #0x30]\n" + ".inst 0x6e45ec09 // bfmmla v9.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec2f // bfmmla v15.4s, v1.8h, v5.8h\n" + "add x22, x22, #0x40\n" + ".inst 0x6e46ec32 // bfmmla v18.4s, v1.8h, v6.8h\n" + ".inst 0x6e45ec55 // bfmmla v21.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec58 // bfmmla v24.4s, v2.8h, v6.8h\n" + ".inst 0x6e45ec7b // bfmmla v27.4s, v3.8h, v5.8h\n" + ".inst 0x6e46ec7e // bfmmla v30.4s, v3.8h, v6.8h\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + ".inst 0x6e44ec0d // bfmmla v13.4s, v0.8h, v4.8h\n" + ".inst 0x6e47ec30 // bfmmla v16.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec33 // bfmmla v19.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n" + ".inst 0x6e44ec59 // bfmmla v25.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7c // bfmmla v28.4s, v3.8h, v7.8h\n" + ".inst 0x6e44ec7f // bfmmla v31.4s, v3.8h, v4.8h\n" + "cbz x20, 6f\n" + "ldr q5, [x22, #0x0]\n" + "ldr q0, [%x[Apanel], #0x0]\n" + "ldr q1, [%x[Apanel], #0x10]\n" + "ldr q6, [x22, #0x10]\n" + "ldr q2, [%x[Apanel], #0x20]\n" + "ldr q3, [%x[Apanel], #0x30]\n" + "add %x[Apanel], %x[Apanel], #0x40\n" + "ldr q7, [x22, #0x20]\n" + "ldr q4, [x22, #0x30]\n" + ".inst 0x6e45ec08 // bfmmla v8.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec2e // bfmmla v14.4s, v1.8h, v5.8h\n" + ".inst 0x6e46ec31 // bfmmla v17.4s, v1.8h, v6.8h\n" + ".inst 0x6e45ec54 // bfmmla v20.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + ".inst 0x6e45ec7a // bfmmla v26.4s, v3.8h, v5.8h\n" + "ldr q5, [x22, #0x40]\n" + ".inst 0x6e46ec7d // bfmmla v29.4s, v3.8h, v6.8h\n" + "ldr q6, [x22, #0x50]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e47ec2f // bfmmla v15.4s, v1.8h, v7.8h\n" + "add x22, x22, #0x60\n" + ".inst 0x6e44ec32 // bfmmla v18.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7b // bfmmla v27.4s, v3.8h, v7.8h\n" + ".inst 0x6e44ec7e // bfmmla v30.4s, v3.8h, v4.8h\n" + ".inst 0x6e45ec0a // bfmmla v10.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec30 // bfmmla v16.4s, v1.8h, v5.8h\n" + ".inst 0x6e46ec33 // bfmmla v19.4s, v1.8h, v6.8h\n" + ".inst 0x6e45ec56 // bfmmla v22.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" + ".inst 0x6e45ec7c // bfmmla v28.4s, v3.8h, v5.8h\n" + ".inst 0x6e46ec7f // bfmmla v31.4s, v3.8h, v6.8h\n" + "6:" // multiply loop done + "add x20, %x[args_ptr], %[offset_max]\n" + "uzp1 v7.2d, v8.2d, v11.2d\n" + "uzp2 v8.2d, v8.2d, v11.2d\n" + "ld1r { v1.4s }, [x20]\n" + "uzp1 v11.2d, v9.2d, v12.2d\n" + "uzp2 v9.2d, v9.2d, v12.2d\n" + "uzp1 v12.2d, v10.2d, v13.2d\n" + "uzp2 v10.2d, v10.2d, v13.2d\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x20]\n" + "uzp1 v13.2d, v14.2d, v17.2d\n" + "uzp2 v14.2d, v14.2d, v17.2d\n" + "uzp1 v17.2d, v15.2d, v18.2d\n" + "uzp2 v15.2d, v15.2d, v18.2d\n" + "cmp x23, #0xc\n" + "uzp1 v18.2d, v16.2d, v19.2d\n" + "uzp2 v16.2d, v16.2d, v19.2d\n" + "uzp1 v19.2d, v20.2d, v23.2d\n" + "uzp2 v20.2d, v20.2d, v23.2d\n" + "uzp1 v23.2d, v21.2d, v24.2d\n" + "uzp2 v21.2d, v21.2d, v24.2d\n" + "uzp1 v24.2d, v22.2d, v25.2d\n" + "uzp2 v22.2d, v22.2d, v25.2d\n" + "uzp1 v25.2d, v26.2d, v29.2d\n" + "uzp2 v26.2d, v26.2d, v29.2d\n" + "uzp1 v29.2d, v27.2d, v30.2d\n" + "uzp2 v27.2d, v27.2d, v30.2d\n" + "uzp1 v30.2d, v28.2d, v31.2d\n" + "uzp2 v28.2d, v28.2d, v31.2d\n" + "fmin v7.4s, v7.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v19.4s, v19.4s, v1.4s\n" + "fmin v23.4s, v23.4s, v1.4s\n" + "fmin v24.4s, v24.4s, v1.4s\n" + "fmin v20.4s, v20.4s, v1.4s\n" + "fmin v21.4s, v21.4s, v1.4s\n" + "fmin v22.4s, v22.4s, v1.4s\n" + "fmin v25.4s, v25.4s, v1.4s\n" + "fmin v29.4s, v29.4s, v1.4s\n" + "fmin v30.4s, v30.4s, v1.4s\n" + "fmin v26.4s, v26.4s, v1.4s\n" + "fmin v27.4s, v27.4s, v1.4s\n" + "fmin v28.4s, v28.4s, v1.4s\n" + "fmax v7.4s, v7.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v23.4s, v23.4s, v0.4s\n" + "fmax v24.4s, v24.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "fmax v25.4s, v25.4s, v0.4s\n" + "fmax v29.4s, v29.4s, v0.4s\n" + "fmax v30.4s, v30.4s, v0.4s\n" + "fmax v26.4s, v26.4s, v0.4s\n" + "fmax v27.4s, v27.4s, v0.4s\n" + "fmax v28.4s, v28.4s, v0.4s\n" + "blt 7f\n" + "str q26, [x24, #0x0]\n" + "str q27, [x24, #0x10]\n" + "str q28, [x24, #0x20]\n" + "add x24, x24, #0x30\n" + "str q25, [x9, #0x0]\n" + "str q29, [x9, #0x10]\n" + "str q30, [x9, #0x20]\n" + "add x9, x9, #0x30\n" + "str q20, [x25, #0x0]\n" + "str q21, [x25, #0x10]\n" + "str q22, [x25, #0x20]\n" + "add x25, x25, #0x30\n" + "str q19, [x11, #0x0]\n" + "str q23, [x11, #0x10]\n" + "str q24, [x11, #0x20]\n" + "add x11, x11, #0x30\n" + "str q14, [x26, #0x0]\n" + "str q15, [x26, #0x10]\n" + "str q16, [x26, #0x20]\n" + "add x26, x26, #0x30\n" + "str q13, [x10, #0x0]\n" + "str q17, [x10, #0x10]\n" + "str q18, [x10, #0x20]\n" + "add x10, x10, #0x30\n" + "str q8, [x27, #0x0]\n" + "str q9, [x27, #0x10]\n" + "str q10, [x27, #0x20]\n" + "add x27, x27, #0x30\n" + "str q7, [%x[Cpanel], #0x0]\n" + "str q11, [%x[Cpanel], #0x10]\n" + "str q12, [%x[Cpanel], #0x20]\n" + "add %x[Cpanel], %x[Cpanel], #0x30\n" + "b 14f\n" + "7:" // partial output + "tbz x23, #3, 9f\n" + "st1 { v26.4s }, [x24], #0x10\n" + "st1 { v27.4s }, [x24], #0x10\n" + "st1 { v25.4s }, [x9], #0x10\n" + "st1 { v29.4s }, [x9], #0x10\n" + "st1 { v20.4s }, [x25], #0x10\n" + "st1 { v21.4s }, [x25], #0x10\n" + "st1 { v19.4s }, [x11], #0x10\n" + "st1 { v23.4s }, [x11], #0x10\n" + "st1 { v14.4s }, [x26], #0x10\n" + "st1 { v15.4s }, [x26], #0x10\n" + "st1 { v13.4s }, [x10], #0x10\n" + "st1 { v17.4s }, [x10], #0x10\n" + "st1 { v8.4s }, [x27], #0x10\n" + "st1 { v9.4s }, [x27], #0x10\n" + "st1 { v7.4s }, [%x[Cpanel]], #0x10\n" + "st1 { v11.4s }, [%x[Cpanel]], #0x10\n" + "tbz x23, #1, 8f\n" + "str d28, [x24], #0x8\n" + "str d30, [x9], #0x8\n" + "str d22, [x25], #0x8\n" + "str d24, [x11], #0x8\n" + "str d16, [x26], #0x8\n" + "str d18, [x10], #0x8\n" + "str d10, [x27], #0x8\n" + "str d12, [%x[Cpanel]], #0x8\n" + "tbz x23, #0, 13f\n" + "st1 { v28.s }[2], [x24]\n" + "st1 { v30.s }[2], [x9]\n" + "st1 { v22.s }[2], [x25]\n" + "st1 { v24.s }[2], [x11]\n" + "st1 { v16.s }[2], [x26]\n" + "st1 { v18.s }[2], [x10]\n" + "st1 { v10.s }[2], [x27]\n" + "st1 { v12.s }[2], [%x[Cpanel]]\n" + "b 13f\n" + "8:" // partial result store: partial_1_8 + "tbz x23, #0, 13f\n" + "str s28, [x24, #0x0]\n" + "str s30, [x9, #0x0]\n" + "str s22, [x25, #0x0]\n" + "str s24, [x11, #0x0]\n" + "str s16, [x26, #0x0]\n" + "str s18, [x10, #0x0]\n" + "str s10, [x27, #0x0]\n" + "str s12, [%x[Cpanel], #0x0]\n" + "b 13f\n" + "9:" // partial result store: partial_4_0 + "tbz x23, #2, 11f\n" + "st1 { v26.4s }, [x24], #0x10\n" + "st1 { v25.4s }, [x9], #0x10\n" + "st1 { v20.4s }, [x25], #0x10\n" + "st1 { v19.4s }, [x11], #0x10\n" + "st1 { v14.4s }, [x26], #0x10\n" + "st1 { v13.4s }, [x10], #0x10\n" + "st1 { v8.4s }, [x27], #0x10\n" + "st1 { v7.4s }, [%x[Cpanel]], #0x10\n" + "tbz x23, #1, 10f\n" + "str d27, [x24], #0x8\n" + "str d29, [x9], #0x8\n" + "str d21, [x25], #0x8\n" + "str d23, [x11], #0x8\n" + "str d15, [x26], #0x8\n" + "str d17, [x10], #0x8\n" + "str d9, [x27], #0x8\n" + "str d11, [%x[Cpanel]], #0x8\n" + "tbz x23, #0, 13f\n" + "st1 { v27.s }[2], [x24]\n" + "st1 { v29.s }[2], [x9]\n" + "st1 { v21.s }[2], [x25]\n" + "st1 { v23.s }[2], [x11]\n" + "st1 { v15.s }[2], [x26]\n" + "st1 { v17.s }[2], [x10]\n" + "st1 { v9.s }[2], [x27]\n" + "st1 { v11.s }[2], [%x[Cpanel]]\n" + "b 13f\n" + "10:" // partial result store: partial_1_4 + "tbz x23, #0, 13f\n" + "str s27, [x24, #0x0]\n" + "str s29, [x9, #0x0]\n" + "str s21, [x25, #0x0]\n" + "str s23, [x11, #0x0]\n" + "str s15, [x26, #0x0]\n" + "str s17, [x10, #0x0]\n" + "str s9, [x27, #0x0]\n" + "str s11, [%x[Cpanel], #0x0]\n" + "b 13f\n" + "11:" // partial result store: partial_2_0 + "tbz x23, #1, 12f\n" + "str d26, [x24], #0x8\n" + "str d25, [x9], #0x8\n" + "str d20, [x25], #0x8\n" + "str d19, [x11], #0x8\n" + "str d14, [x26], #0x8\n" + "str d13, [x10], #0x8\n" + "str d8, [x27], #0x8\n" + "str d7, [%x[Cpanel]], #0x8\n" + "tbz x23, #0, 13f\n" + "st1 { v26.s }[2], [x24]\n" + "st1 { v25.s }[2], [x9]\n" + "st1 { v20.s }[2], [x25]\n" + "st1 { v19.s }[2], [x11]\n" + "st1 { v14.s }[2], [x26]\n" + "st1 { v13.s }[2], [x10]\n" + "st1 { v8.s }[2], [x27]\n" + "st1 { v7.s }[2], [%x[Cpanel]]\n" + "b 13f\n" + "12:" // partial result store: partial_1_0 + "str s26, [x24, #0x0]\n" + "str s25, [x9, #0x0]\n" + "str s20, [x25, #0x0]\n" + "str s19, [x11, #0x0]\n" + "str s14, [x26, #0x0]\n" + "str s13, [x10, #0x0]\n" + "str s8, [x27, #0x0]\n" + "str s7, [%x[Cpanel], #0x0]\n" + "13:" // partial result store: Done + "14:" // store done + "subs x23, x23, #0xc\n" + "bgt 3b\n" + "subs %x[M], %x[M], #0x8\n" + "mov %x[Cpanel], x28\n" + "bgt 1b\n" + : [Apanel] "+&r"(Apanel), [Cpanel] "+&r"(Cpanel), [M] "+&r"(M) + : [args_ptr] "r"(&ka), [ldc] "r"(ldc * sizeof(float)), [offset_max] "I"(offsetof(KernelArgs, maxval)), + [offset_min] "I"(offsetof(KernelArgs, minval)), [offsetof_Bpanel] "I"(offsetof(KernelArgs, Bpanel)), + [offsetof_K] "I"(offsetof(KernelArgs, K)), [offsetof_N] "I"(offsetof(KernelArgs, N)) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", "x9", "x10", "x11", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h new file mode 100644 index 0000000000000000000000000000000000000000..e870fb2a09706234cb1c82c0c837fa606d10b044 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h @@ -0,0 +1,127 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_BF16. +#else // Architectural features check. + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// -------------------------------------------------- + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return The m step value. +size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); + +/// Gets mr value. +/// +/// This is the packing parameter which must be used to pack the LHS matrix. +/// +/// @return The mr value. +size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); + +/// Gets nr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The nr value. +size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); + +/// Gets kr value. +/// +/// This is the packing parameter which must be used to pack the LHS & RHS matrices. +/// +/// @return The kr value. +size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); + +/// Gets sr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The sr value. +size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); + +/// Gets the offset in bytes to the data element in the packed LHS matrix buffer. +/// +/// @param[in] m_idx Row index. +/// @param[in] k Number of columns. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m_idx, size_t k); + +/// Gets the offset in bytes to the data element in the packed RHS matrix buffer. +/// +/// @param[in] n_idx Row index. +/// @param[in] k Number of columns. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t n_idx, size_t k); + +/// Gets the offset in bytes to the data element in the destination matrix buffer. +/// +/// @param[in] m_idx Row index. +/// @param[in] n_idx Column index. +/// @param[in] stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m_idx, size_t n_idx, size_t stride); + +/// Gets the size in bytes of the destination matrix buffer. +/// +/// @param[in] m Number of rows. +/// @param[in] n Number of columns. +/// +/// @return The size in bytes of the destination matrix buffer. +size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m, size_t n); + +/// Runs the matrix multiplication microkernel followed by a clamp operation. +/// +/// The pointer of each buffers (packed LHS, packed RHS and output) needs to be added with offset +/// calculated using the following functions: +/// +/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla. +/// +/// @param[in] m Number of output rows to be computed. +/// @param[in] n Number of output columns to be computed. +/// @param[in] k Common dimension of the LHS and RHS operand. +/// @param[in] lhs_packed Packed LHS buffer. +/// @param[in] rhs_packed Packed RHS buffer. +/// @param[out] dst Output matrix buffer. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) +/// @param[in] clamp_min Minimum value to clamp the final result. +/// @param[in] clamp_max Maximum value to clamp the final result. +void kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla( + size_t m, size_t n, size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, size_t dst_stride_row, size_t dst_stride_col, // + float clamp_min, float clamp_max); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..62f89279181dadb1e2dd58aae033fcc893b96ec1 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h @@ -0,0 +1,57 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_BF16 +#else // Architectural features check. + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// All micro-kernels variants of the same type share the same interfaces +// In this case, the micro-kernel type is: matmul_clamp_f32_bf16p_bf16p + +/// Micro-kernel helper functions ("get" methods) +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_m_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_n_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_mr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_nr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_kr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_sr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_size_func_t)(size_t m, size_t n); + +/// Micro-kernel core function ("run" method) +typedef void (*kai_matmul_clamp_f32_bf16p_bf16p_run_matmul_func_t)( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max); + +/// Micro-kernel interface +struct kai_matmul_clamp_f32_bf16p_bf16p_ukernel { + kai_matmul_clamp_f32_bf16p_bf16p_get_m_step_func_t get_m_step; + kai_matmul_clamp_f32_bf16p_bf16p_get_n_step_func_t get_n_step; + kai_matmul_clamp_f32_bf16p_bf16p_get_mr_func_t get_mr; + kai_matmul_clamp_f32_bf16p_bf16p_get_nr_func_t get_nr; + kai_matmul_clamp_f32_bf16p_bf16p_get_nr_func_t get_kr; + kai_matmul_clamp_f32_bf16p_bf16p_get_sr_func_t get_sr; + kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_packed_offset_func_t get_lhs_packed_offset; + kai_matmul_clamp_f32_bf16p_bf16p_get_rhs_packed_offset_func_t get_rhs_packed_offset; + kai_matmul_clamp_f32_bf16p_bf16p_get_dst_offset_func_t get_dst_offset; + kai_matmul_clamp_f32_bf16p_bf16p_get_dst_size_func_t get_dst_size; + kai_matmul_clamp_f32_bf16p_bf16p_run_matmul_func_t run_matmul; +}; + +#ifdef __cplusplus +} +#endif + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.c new file mode 100644 index 0000000000000000000000000000000000000000..cea886ed29a9212eae88cc3ab75f7aa646ed0e76 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.c @@ -0,0 +1,200 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_BF16. +#else // Architectural features check. + +#define MAX_MR 8 + +#include +#include +#include + +#include "kai/kai_common.h" + +size_t kai_get_m_step_lhs_quant_pack_bf16p_f32_neon(size_t mr) { + return mr; +} + +size_t kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon(size_t m_idx, size_t lhs_stride) { + return m_idx * lhs_stride; +} + +size_t kai_get_lhs_packed_offset_lhs_quant_pack_bf16p_f32_neon( + size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { + KAI_UNUSED(sr); + KAI_ASSUME(m_idx % mr == 0); + + return m_idx * kai_roundup(k, kr) * sizeof(uint16_t); +} + +size_t kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { + KAI_UNUSED(sr); + + return kai_roundup(m, mr) * kai_roundup(k, kr) * sizeof(uint16_t); +} + +void kai_run_lhs_quant_pack_bf16p_f32_neon( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, size_t lhs_stride, + uint16_t* lhs_packed) { + KAI_UNUSED(sr); + KAI_ASSUME(lhs != NULL); + KAI_ASSUME(lhs_packed != NULL); + + KAI_ASSUME(m_idx_start == 0); + KAI_ASSUME(mr <= MAX_MR); + + const size_t block_height = mr; + const size_t row_offset = 0; + + const void* in[MAX_MR]; + + for (size_t block_y = 0; block_y < m; block_y += block_height) { + const size_t height = KAI_MIN(m - block_y, block_height); + void* out = (char*)lhs_packed + block_y * kai_roundup(k, kr) * sizeof(uint16_t); + size_t width = k; + + for (size_t y = 0; y < height; y++) { + in[y] = (char*)lhs + (block_y + y) * lhs_stride; + } + + __asm__ __volatile__( + "ldr x28, [%x[in], #0x0]\n" + "ldr x27, [%x[in], #0x8]\n" + "cmp %x[height], #0x8\n" + "ldr x26, [%x[in], #0x10]\n" + "ldr x25, [%x[in], #0x18]\n" + "ldr x24, [%x[in], #0x20]\n" + "ldr x23, [%x[in], #0x28]\n" + "ldr x22, [%x[in], #0x30]\n" + "ldr x21, [%x[in], #0x38]\n" + "add x28, x28, %x[row_offset], LSL #2\n" + "add x27, x27, %x[row_offset], LSL #2\n" + "add x26, x26, %x[row_offset], LSL #2\n" + "add x25, x25, %x[row_offset], LSL #2\n" + "add x24, x24, %x[row_offset], LSL #2\n" + "add x23, x23, %x[row_offset], LSL #2\n" + "add x22, x22, %x[row_offset], LSL #2\n" + "add x21, x21, %x[row_offset], LSL #2\n" + "beq 1f\n" + "cmp %x[height], #0x2\n" + "mov x21, x28\n" + "csel x27, x27, x28, GE\n" + "csel x26, x26, x28, GT\n" + "cmp %x[height], #0x4\n" + "csel x25, x25, x28, GE\n" + "csel x24, x24, x28, GT\n" + "cmp %x[height], #0x6\n" + "csel x23, x23, x28, GE\n" + "csel x22, x22, x28, GT\n" + "1:" // no_pointer_adj + "cmp %x[width], #0x4\n" + "prfm pldl1keep, [x28, #0x0]\n" + "prfm pldl1keep, [x27, #0x0]\n" + "prfm pldl1keep, [x26, #0x0]\n" + "prfm pldl1keep, [x25, #0x0]\n" + "prfm pldl1keep, [x24, #0x0]\n" + "prfm pldl1keep, [x23, #0x0]\n" + "prfm pldl1keep, [x22, #0x0]\n" + "prfm pldl1keep, [x21, #0x0]\n" + "prfm pldl1keep, [x28, #0x40]\n" + "prfm pldl1keep, [x27, #0x40]\n" + "prfm pldl1keep, [x26, #0x40]\n" + "prfm pldl1keep, [x25, #0x40]\n" + "prfm pldl1keep, [x24, #0x40]\n" + "prfm pldl1keep, [x23, #0x40]\n" + "prfm pldl1keep, [x22, #0x40]\n" + "prfm pldl1keep, [x21, #0x40]\n" + "blt 3f\n" + "2:" // Main loop head + "ldr q19, [x28], #0x10\n" + "ldr q18, [x26], #0x10\n" + "subs %x[width], %x[width], #0x4\n" + "ldr q17, [x24], #0x10\n" + "ldr q16, [x22], #0x10\n" + "cmp %x[width], #0x4\n" + "ldr q23, [x27], #0x10\n" + "ldr q22, [x25], #0x10\n" + "ldr q21, [x23], #0x10\n" + "ldr q20, [x21], #0x10\n" + ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n" + ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n" + ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "prfm pldl1keep, [x28, #0x70]\n" + "prfm pldl1keep, [x27, #0x70]\n" + "prfm pldl1keep, [x26, #0x70]\n" + "prfm pldl1keep, [x25, #0x70]\n" + "prfm pldl1keep, [x24, #0x70]\n" + "prfm pldl1keep, [x23, #0x70]\n" + ".inst 0x4ea16af3 // bfcvtn2 v19.8h, v23.4s\n" + ".inst 0x4ea16ad2 // bfcvtn2 v18.8h, v22.4s\n" + "prfm pldl1keep, [x22, #0x70]\n" + "prfm pldl1keep, [x21, #0x70]\n" + ".inst 0x4ea16ab1 // bfcvtn2 v17.8h, v21.4s\n" + ".inst 0x4ea16a90 // bfcvtn2 v16.8h, v20.4s\n" + "str q19, [%x[out_ptr], #0x0]\n" + "str q18, [%x[out_ptr], #0x10]\n" + "str q17, [%x[out_ptr], #0x20]\n" + "str q16, [%x[out_ptr], #0x30]\n" + "add %x[out_ptr], %x[out_ptr], #0x40\n" + "bge 2b\n" + "3:" // Main loop skip + "cbz %x[width], 6f\n" + "tbz %x[width], #1, 4f\n" + "ldr d19, [x28], #0x8\n" + "ldr d23, [x27], #0x8\n" + "mov x20, #0x1\n" + "ldr d18, [x26], #0x8\n" + "ldr d22, [x25], #0x8\n" + "ldr d17, [x24], #0x8\n" + "ldr d21, [x23], #0x8\n" + "ldr d16, [x22], #0x8\n" + "ldr d20, [x21], #0x8\n" + "tbz %x[width], #0, 5f\n" + "ld1 { v19.s }[2], [x28]\n" + "ld1 { v23.s }[2], [x27]\n" + "ld1 { v18.s }[2], [x26]\n" + "ld1 { v22.s }[2], [x25]\n" + "ld1 { v17.s }[2], [x24]\n" + "ld1 { v21.s }[2], [x23]\n" + "ld1 { v16.s }[2], [x22]\n" + "ld1 { v20.s }[2], [x21]\n" + "b 5f\n" + "4:" // odd_loads_1_0 + "ldr s19, [x28, #0x0]\n" + "ldr s23, [x27, #0x0]\n" + "mov x20, #0x1\n" + "ldr s18, [x26, #0x0]\n" + "ldr s22, [x25, #0x0]\n" + "ldr s17, [x24, #0x0]\n" + "ldr s21, [x23, #0x0]\n" + "ldr s16, [x22, #0x0]\n" + "ldr s20, [x21, #0x0]\n" + "5:" // Odd load end + ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n" + ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n" + ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + ".inst 0x4ea16af3 // bfcvtn2 v19.8h, v23.4s\n" + ".inst 0x4ea16ad2 // bfcvtn2 v18.8h, v22.4s\n" + ".inst 0x4ea16ab1 // bfcvtn2 v17.8h, v21.4s\n" + ".inst 0x4ea16a90 // bfcvtn2 v16.8h, v20.4s\n" + "str q19, [%x[out_ptr], #0x0]\n" + "str q18, [%x[out_ptr], #0x10]\n" + "str q17, [%x[out_ptr], #0x20]\n" + "str q16, [%x[out_ptr], #0x30]\n" + "add %x[out_ptr], %x[out_ptr], #0x40\n" + "6:" // Odds skip + : [out_ptr] "+&r"(out), [width] "+&r"(width) + : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset) + : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "x20", "x21", "x22", "x23", "x24", + "x25", "x26", "x27", "x28"); + } +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..200a66c46ab394d8939fc8a8af55d0e4fa6219bd --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h @@ -0,0 +1,79 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#include +#include + +#include "kai/kai_common.h" + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @param[in] mr Number of rows to be interleaved. +/// +/// @return The m step value. +size_t kai_get_m_step_lhs_quant_pack_bf16p_f32_neon(size_t mr); + +/// Gets the offset in bytes to the data element in the LHS buffer. +/// +/// @param[in] m_idx Row index. +/// @param[in] lhs_stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon(size_t m_idx, size_t lhs_stride); + +/// Gets the offset in bytes to the data element in the packed LHS buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// @param[in] mr Number of rows to be interleaved. +/// @param[in] kr Number of columns to be interleaved. +/// @param[in] sr Unused. Must be 1. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_lhs_quant_pack_bf16p_f32_neon(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr); + +/// Gets the size in bytes of the packed LHS buffer. +/// +/// @param[in] m Number of rows in the unpacked LHS matrix. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// @param[in] mr Number of rows to be interleaved. +/// @param[in] kr Number of columns to be interleaved. +/// @param[in] sr Unused. Must be 1. +/// +/// @return The size in bytes of the packed LHS buffer. +size_t kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr); + +/// Runs the LHS packing function for matrix multiplication. +/// +/// The pointer of each buffers (LHS and packed LHS) needs to be added with offset +/// calculated using the following functions: +/// +/// * LHS: @ref kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon. +/// * Packed LHS: @ref kai_get_lhs_packed_offset_lhs_quant_pack_bf16p_f32_neon. +/// +/// @param[in] m Number of rows of the unpacked LHS matrix. +/// @param[in] k Common dimension between the LHS and RHS matrix. +/// @param[in] mr Block size in M dimension. It must be 8. +/// @param[in] kr Block size in K dimension. It must be 4. +/// @param[in] sr Number of kr splits. It must be 1. +/// @param[in] m_idx_start Unused. Must be 0. +/// @param[in] lhs LHS matrix data buffer. +/// @param[in] lhs_stride Row stride in bytes of the LHS matrix. +/// @param[out] lhs_packed Packed LHS matrix. +void kai_run_lhs_quant_pack_bf16p_f32_neon( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, + void* lhs_packed); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c new file mode 100644 index 0000000000000000000000000000000000000000..bf9eb10720dcda6a0339f16ca50ee217094fe5ad --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c @@ -0,0 +1,464 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) +#error This file must be compiled for AArch64. +#else // Architectural features check. + +#define MAX_NR 12 + +#include +#include +#include +#include + +#include "kai/kai_common.h" + +size_t kai_get_n_step_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t nr) { + return nr; +} + +size_t kai_get_rhs_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n_idx) { + return n_idx * sizeof(float); +} + +size_t kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n_idx) { + return n_idx * sizeof(uint32_t); +} + +size_t kai_get_rhs_packed_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon( + size_t n_idx, size_t k, size_t nr, size_t kr) { + KAI_ASSUME(n_idx % nr == 0); + + return n_idx * (sizeof(uint32_t) + kai_roundup(k, kr) * sizeof(uint16_t)); +} + +size_t kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n, size_t k, size_t nr, size_t kr) { + return kai_get_rhs_packed_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(kai_roundup(n, nr), k, nr, kr); +} + +void kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { + KAI_ASSUME(num_groups == 1); + KAI_ASSUME(sr == 1); + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(scale == NULL); + KAI_ASSUME(rhs_packed != NULL); + KAI_ASSUME(extra_bytes == 0); + KAI_ASSUME(params == NULL); + KAI_ASSUME(nr <= MAX_NR); + + size_t height = k; + const size_t width = n; + const void* in = (void*)rhs; + void* out = rhs_packed; + const size_t in_stride = rhs_stride; + const float* pad_row = rhs; + + // Fill zeros if bias is nullptr + size_t bias_step = nr * sizeof(float); + uint8_t zero_bias[MAX_NR * sizeof(float)]; + + if (bias == NULL) { + memset(zero_bias, 0, MAX_NR * sizeof(float)); + bias_step = 0; + } + + const void* bias_ptr = bias == NULL ? (void*)zero_bias : (void*)bias; + + const size_t out_stride = nr * kai_roundup(height, kr) * sizeof(uint16_t) + nr * sizeof(uint32_t); + + __asm__ __volatile__( + "mov x22, %x[width]\n" + "mov x21, %x[out]\n" + "cmp x22, #0xc\n" + "blt 2f\n" + "1:" // Bias: Full loop + "ldr q16, [%x[bias], #0x0]\n" + "ldr q26, [%x[bias], #0x10]\n" + "sub x22, x22, #0xc\n" + "ldr q8, [%x[bias], #0x20]\n" + "cmp x22, #0xc\n" + "add %x[bias], %x[bias], %x[bias_step]\n" + "str q16, [x21, #0x0]\n" + "str q26, [x21, #0x10]\n" + "str q8, [x21, #0x20]\n" + "add x21, x21, %x[out_stride]\n" + "bge 1b\n" + "cbz x22, 3f\n" + "2:" // Bias: Tail loop + "ldr w20, [%x[bias], #0x0]\n" + "sub x22, x22, #0x1\n" + "add %x[bias], %x[bias], #0x4\n" + "cmp x22, #0x0\n" + "str w20, [x21]\n" + "add x21, x21, #0x4\n" + "bgt 2b\n" + "3:" // Bias: Done + "cmp %x[height], #0x8\n" + "add %x[out], %x[out], #0x30\n" + "blt 12f\n" + "4:" // Main row loop: Head + "mov x9, %x[in]\n" + "mov x28, %x[width]\n" + "mov x27, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "add x26, x9, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "cmp x28, #0xc\n" + "add x23, x24, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x21, x22, %x[in_stride]\n" + "add x20, x21, %x[in_stride]\n" + "add %x[in], x20, %x[in_stride]\n" + "blt 6f\n" + "5:" // Main row loop: Column loop + "ldr q28, [x9], #0x10\n" + "ldr q27, [x26], #0x10\n" + "sub x28, x28, #0xc\n" + "ldr q11, [x25], #0x10\n" + "ldr q5, [x24], #0x10\n" + "cmp x28, #0xc\n" + "ldr q14, [x23], #0x10\n" + "ldr q6, [x22], #0x10\n" + "ldr q2, [x21], #0x10\n" + "ldr q18, [x20], #0x10\n" + "ldr q1, [x9], #0x10\n" + "ldr q7, [x26], #0x10\n" + "zip1 v15.4s, v28.4s, v11.4s\n" + "zip1 v8.4s, v27.4s, v5.4s\n" + "ldr q3, [x25], #0x10\n" + "ldr q23, [x24], #0x10\n" + "zip2 v17.4s, v28.4s, v11.4s\n" + "zip2 v27.4s, v27.4s, v5.4s\n" + "ldr q5, [x23], #0x10\n" + "ldr q30, [x22], #0x10\n" + "zip1 v26.4s, v14.4s, v2.4s\n" + "zip1 v31.4s, v6.4s, v18.4s\n" + "ldr q20, [x21], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip2 v12.4s, v14.4s, v2.4s\n" + "zip2 v24.4s, v6.4s, v18.4s\n" + "ldr q29, [x9], #0x10\n" + "ldr q6, [x26], #0x10\n" + "zip1 v18.4s, v1.4s, v3.4s\n" + "zip1 v4.4s, v7.4s, v23.4s\n" + "ldr q22, [x25], #0x10\n" + "ldr q0, [x24], #0x10\n" + "zip2 v3.4s, v1.4s, v3.4s\n" + "zip2 v1.4s, v7.4s, v23.4s\n" + "ldr q2, [x23], #0x10\n" + "ldr q10, [x22], #0x10\n" + "zip1 v28.4s, v5.4s, v20.4s\n" + "zip1 v14.4s, v30.4s, v16.4s\n" + "ldr q9, [x21], #0x10\n" + "ldr q23, [x20], #0x10\n" + "zip2 v13.4s, v5.4s, v20.4s\n" + "zip2 v30.4s, v30.4s, v16.4s\n" + "zip1 v16.4s, v29.4s, v22.4s\n" + "zip1 v5.4s, v6.4s, v0.4s\n" + "zip2 v22.4s, v29.4s, v22.4s\n" + "zip2 v0.4s, v6.4s, v0.4s\n" + "zip1 v7.4s, v2.4s, v9.4s\n" + "zip1 v19.4s, v10.4s, v23.4s\n" + "zip2 v21.4s, v2.4s, v9.4s\n" + "zip2 v25.4s, v10.4s, v23.4s\n" + "zip1 v11.4s, v15.4s, v8.4s\n" + "zip1 v9.4s, v17.4s, v27.4s\n" + "zip1 v6.4s, v18.4s, v4.4s\n" + "zip1 v2.4s, v3.4s, v1.4s\n" + "zip1 v29.4s, v16.4s, v5.4s\n" + "zip1 v20.4s, v22.4s, v0.4s\n" + "zip1 v10.4s, v26.4s, v31.4s\n" + "zip1 v23.4s, v12.4s, v24.4s\n" + ".inst 0x0ea1696b // bfcvtn v11.4h, v11.4s\n" + "zip2 v8.4s, v15.4s, v8.4s\n" + "zip1 v15.4s, v28.4s, v14.4s\n" + ".inst 0x0ea16929 // bfcvtn v9.4h, v9.4s\n" + "zip2 v27.4s, v17.4s, v27.4s\n" + "zip1 v17.4s, v13.4s, v30.4s\n" + ".inst 0x0ea168c6 // bfcvtn v6.4h, v6.4s\n" + "zip2 v4.4s, v18.4s, v4.4s\n" + "zip1 v18.4s, v7.4s, v19.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "zip2 v1.4s, v3.4s, v1.4s\n" + "zip1 v3.4s, v21.4s, v25.4s\n" + ".inst 0x0ea16bbd // bfcvtn v29.4h, v29.4s\n" + "zip2 v5.4s, v16.4s, v5.4s\n" + ".inst 0x0ea16a94 // bfcvtn v20.4h, v20.4s\n" + "zip2 v0.4s, v22.4s, v0.4s\n" + ".inst 0x0ea16956 // bfcvtn v22.4h, v10.4s\n" + "zip2 v31.4s, v26.4s, v31.4s\n" + ".inst 0x0ea16aea // bfcvtn v10.4h, v23.4s\n" + "zip2 v26.4s, v12.4s, v24.4s\n" + ".inst 0x0ea169ef // bfcvtn v15.4h, v15.4s\n" + "zip2 v12.4s, v28.4s, v14.4s\n" + ".inst 0x0ea16a2e // bfcvtn v14.4h, v17.4s\n" + "zip2 v24.4s, v13.4s, v30.4s\n" + ".inst 0x0ea16a57 // bfcvtn v23.4h, v18.4s\n" + "zip2 v18.4s, v7.4s, v19.4s\n" + ".inst 0x0ea16871 // bfcvtn v17.4h, v3.4s\n" + "zip2 v16.4s, v21.4s, v25.4s\n" + ".inst 0x4ea1690b // bfcvtn2 v11.8h, v8.4s\n" + ".inst 0x4ea16b69 // bfcvtn2 v9.8h, v27.4s\n" + ".inst 0x4ea16886 // bfcvtn2 v6.8h, v4.4s\n" + ".inst 0x4ea16822 // bfcvtn2 v2.8h, v1.4s\n" + ".inst 0x4ea168bd // bfcvtn2 v29.8h, v5.4s\n" + ".inst 0x4ea16814 // bfcvtn2 v20.8h, v0.4s\n" + ".inst 0x4ea16bf6 // bfcvtn2 v22.8h, v31.4s\n" + ".inst 0x4ea16b4a // bfcvtn2 v10.8h, v26.4s\n" + "str q11, [x27, #0x0]\n" + ".inst 0x4ea1698f // bfcvtn2 v15.8h, v12.4s\n" + ".inst 0x4ea16b0e // bfcvtn2 v14.8h, v24.4s\n" + "str q9, [x27, #0x10]\n" + ".inst 0x4ea16a57 // bfcvtn2 v23.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q6, [x27, #0x20]\n" + "str q2, [x27, #0x30]\n" + "str q29, [x27, #0x40]\n" + "str q20, [x27, #0x50]\n" + "str q22, [x27, #0x60]\n" + "str q10, [x27, #0x70]\n" + "str q15, [x27, #0x80]\n" + "str q14, [x27, #0x90]\n" + "str q23, [x27, #0xa0]\n" + "str q17, [x27, #0xb0]\n" + "add x27, x27, %x[out_stride]\n" + "bge 5b\n" + "6:" // Main row loop: Column loop skip + "cbz x28, 11f\n" + "cmp x28, #0x4\n" + "movi v16.16b, #0x0\n" + "str q16, [x27, #0x0]\n" + "str q16, [x27, #0x10]\n" + "str q16, [x27, #0x20]\n" + "str q16, [x27, #0x30]\n" + "str q16, [x27, #0x40]\n" + "str q16, [x27, #0x50]\n" + "str q16, [x27, #0x60]\n" + "str q16, [x27, #0x70]\n" + "str q16, [x27, #0x80]\n" + "str q16, [x27, #0x90]\n" + "str q16, [x27, #0xa0]\n" + "str q16, [x27, #0xb0]\n" + "blt 8f\n" + "7:" // Main row loop: width 4 loop: loop + "ldr q25, [x9], #0x10\n" + "ldr q24, [x26], #0x10\n" + "sub x28, x28, #0x4\n" + "ldr q21, [x25], #0x10\n" + "ldr q20, [x24], #0x10\n" + "cmp x28, #0x4\n" + "ldr q23, [x23], #0x10\n" + "ldr q19, [x22], #0x10\n" + "ldr q18, [x21], #0x10\n" + "ldr q17, [x20], #0x10\n" + "zip1 v22.4s, v25.4s, v21.4s\n" + "zip1 v16.4s, v24.4s, v20.4s\n" + "zip2 v21.4s, v25.4s, v21.4s\n" + "zip2 v20.4s, v24.4s, v20.4s\n" + "zip1 v27.4s, v23.4s, v18.4s\n" + "zip1 v26.4s, v19.4s, v17.4s\n" + "zip2 v25.4s, v23.4s, v18.4s\n" + "zip2 v24.4s, v19.4s, v17.4s\n" + "zip1 v19.4s, v22.4s, v16.4s\n" + "zip1 v18.4s, v21.4s, v20.4s\n" + "zip1 v17.4s, v27.4s, v26.4s\n" + "zip2 v23.4s, v22.4s, v16.4s\n" + "zip1 v16.4s, v25.4s, v24.4s\n" + "zip2 v22.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a75 // bfcvtn v21.4h, v19.4s\n" + ".inst 0x0ea16a54 // bfcvtn v20.4h, v18.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v27.4s, v26.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v25.4s, v24.4s\n" + ".inst 0x4ea16af5 // bfcvtn2 v21.8h, v23.4s\n" + ".inst 0x4ea16ad4 // bfcvtn2 v20.8h, v22.4s\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q21, [x27, #0x0]\n" + "str q20, [x27, #0x10]\n" + "str q19, [x27, #0x60]\n" + "str q17, [x27, #0x70]\n" + "add x27, x27, #0x20\n" + "bge 7b\n" + "8:" // Main row loop: width 4 loop: skip + "cmp x28, #0x1\n" + "blt 10f\n" + "9:" // Main row loop: width 1 loop: loop + "ldr s23, [x9], #0x4\n" + "ldr s22, [x26], #0x4\n" + "sub x28, x28, #0x1\n" + "ldr s19, [x25], #0x4\n" + "ldr s17, [x24], #0x4\n" + "cmp x28, #0x1\n" + "ldr s21, [x23], #0x4\n" + "ldr s20, [x22], #0x4\n" + "ldr s18, [x21], #0x4\n" + "ldr s16, [x20], #0x4\n" + "zip1 v19.4s, v23.4s, v19.4s\n" + "zip1 v17.4s, v22.4s, v17.4s\n" + "zip1 v18.4s, v21.4s, v18.4s\n" + "zip1 v16.4s, v20.4s, v16.4s\n" + "zip1 v17.4s, v19.4s, v17.4s\n" + "zip1 v16.4s, v18.4s, v16.4s\n" + ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "str d17, [x27, #0x0]\n" + "str d16, [x27, #0x60]\n" + "add x27, x27, #0x8\n" + "bge 9b\n" + "10:" // Main row loop: width 1 loop: skip + "11:" // Main row loop: odd col skip + "cmp %x[height], #0x8\n" + "add %x[out], %x[out], #0xc0\n" + "bge 4b\n" + "cbz %x[height], 21f\n" + "12:" // Main loop skip + "13:" // Tail row loop: Head + "mov x9, %x[in]\n" + "mov x20, %x[width]\n" + "cmp %x[height], #0x3\n" + "mov x27, %x[out]\n" + "add x26, x9, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "csel x25, x25, %x[pad_row], GE\n" + "add %x[in], x24, %x[in_stride]\n" + "csel x24, x24, %x[pad_row], GT\n" + "cmp %x[height], #0x1\n" + "sub %x[height], %x[height], #0x4\n" + "csel x26, x26, %x[pad_row], GT\n" + "cmp x20, #0xc\n" + "blt 15f\n" + "14:" // Tail row loop: Column loop + "ldr q24, [x9], #0x10\n" + "ldr q23, [x26], #0x10\n" + "sub x20, x20, #0xc\n" + "ldr q22, [x25], #0x10\n" + "ldr q16, [x24], #0x10\n" + "cmp x20, #0xc\n" + "ldr q28, [x9], #0x10\n" + "ldr q27, [x26], #0x10\n" + "ldr q21, [x25], #0x10\n" + "ldr q20, [x24], #0x10\n" + "ldr q19, [x9], #0x10\n" + "zip1 v26.4s, v24.4s, v22.4s\n" + "zip1 v25.4s, v23.4s, v16.4s\n" + "ldr q18, [x26], #0x10\n" + "ldr q17, [x25], #0x10\n" + "zip2 v24.4s, v24.4s, v22.4s\n" + "zip2 v23.4s, v23.4s, v16.4s\n" + "ldr q16, [x24], #0x10\n" + "zip1 v2.4s, v28.4s, v21.4s\n" + "zip1 v22.4s, v27.4s, v20.4s\n" + "zip2 v1.4s, v28.4s, v21.4s\n" + "zip2 v0.4s, v27.4s, v20.4s\n" + "zip1 v31.4s, v19.4s, v17.4s\n" + "zip1 v30.4s, v18.4s, v16.4s\n" + "zip2 v29.4s, v19.4s, v17.4s\n" + "zip2 v28.4s, v18.4s, v16.4s\n" + "zip1 v21.4s, v26.4s, v25.4s\n" + "zip1 v20.4s, v24.4s, v23.4s\n" + "zip1 v19.4s, v2.4s, v22.4s\n" + "zip1 v18.4s, v1.4s, v0.4s\n" + "zip1 v17.4s, v31.4s, v30.4s\n" + "zip1 v16.4s, v29.4s, v28.4s\n" + ".inst 0x0ea16abb // bfcvtn v27.4h, v21.4s\n" + "zip2 v26.4s, v26.4s, v25.4s\n" + ".inst 0x0ea16a99 // bfcvtn v25.4h, v20.4s\n" + "zip2 v24.4s, v24.4s, v23.4s\n" + ".inst 0x0ea16a77 // bfcvtn v23.4h, v19.4s\n" + "zip2 v22.4s, v2.4s, v22.4s\n" + ".inst 0x0ea16a55 // bfcvtn v21.4h, v18.4s\n" + "zip2 v20.4s, v1.4s, v0.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v31.4s, v30.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v29.4s, v28.4s\n" + ".inst 0x4ea16b5b // bfcvtn2 v27.8h, v26.4s\n" + ".inst 0x4ea16b19 // bfcvtn2 v25.8h, v24.4s\n" + ".inst 0x4ea16ad7 // bfcvtn2 v23.8h, v22.4s\n" + ".inst 0x4ea16a95 // bfcvtn2 v21.8h, v20.4s\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q27, [x27, #0x0]\n" + "str q25, [x27, #0x10]\n" + "str q23, [x27, #0x20]\n" + "str q21, [x27, #0x30]\n" + "str q19, [x27, #0x40]\n" + "str q17, [x27, #0x50]\n" + "add x27, x27, %x[out_stride]\n" + "bge 14b\n" + "15:" // Tail row loop: Column loop skip + "cbz x20, 20f\n" + "cmp x20, #0x4\n" + "movi v16.16b, #0x0\n" + "str q16, [x27, #0x0]\n" + "str q16, [x27, #0x10]\n" + "str q16, [x27, #0x20]\n" + "str q16, [x27, #0x30]\n" + "str q16, [x27, #0x40]\n" + "str q16, [x27, #0x50]\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr q21, [x9], #0x10\n" + "ldr q20, [x26], #0x10\n" + "sub x20, x20, #0x4\n" + "ldr q19, [x25], #0x10\n" + "ldr q17, [x24], #0x10\n" + "cmp x20, #0x4\n" + "zip1 v18.4s, v21.4s, v19.4s\n" + "zip1 v16.4s, v20.4s, v17.4s\n" + "zip2 v21.4s, v21.4s, v19.4s\n" + "zip2 v20.4s, v20.4s, v17.4s\n" + "zip1 v17.4s, v18.4s, v16.4s\n" + "zip2 v19.4s, v18.4s, v16.4s\n" + "zip1 v16.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a32 // bfcvtn v18.4h, v17.4s\n" + "zip2 v17.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + ".inst 0x4ea16a72 // bfcvtn2 v18.8h, v19.4s\n" + ".inst 0x4ea16a30 // bfcvtn2 v16.8h, v17.4s\n" + "str q18, [x27, #0x0]\n" + "str q16, [x27, #0x10]\n" + "add x27, x27, #0x20\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr s19, [x9], #0x4\n" + "ldr s18, [x26], #0x4\n" + "sub x20, x20, #0x1\n" + "ldr s17, [x25], #0x4\n" + "ldr s16, [x24], #0x4\n" + "cmp x20, #0x1\n" + "zip1 v17.4s, v19.4s, v17.4s\n" + "zip1 v16.4s, v18.4s, v16.4s\n" + "zip1 v16.4s, v17.4s, v16.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "str d16, [x27, #0x0]\n" + "add x27, x27, #0x8\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "20:" // Tail row loop: odd col skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x60\n" + "bge 13b\n" + "21:" // Done + : [bias] "+&r"(bias_ptr), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out) + : [bias_step] "r"(bias_step), [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), + [width] "r"(width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..f786c7a7c92865d84184ee079072250ff3c5ce08 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h @@ -0,0 +1,84 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets n step value. +/// +/// The starting row index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(void); + +/// Gets the offset in bytes to the data element in the RHS matrix buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n_idx); + +/// Gets the offset in bytes to the data element in the bias buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n_idx); + +/// Gets the offset in bytes to the data element in the packed RHS buffer. +/// +/// @param[in] n_idx Row index. +/// @param[in] k Number of columns. +/// @param[in] nr Block size in N dimension. +/// @param[in] kr Block size in K dimension. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n_idx, size_t k, size_t nr, size_t kr); + +/// Gets the size in bytes of the packed RHS buffer. +/// +/// @param[in] n Number of rows. +/// @param[in] k Number of columns. +/// @param[in] nr Block size in N dimension. +/// @param[in] kr Block size in K dimension. +/// +/// @return The size in bytes of the packed RHS buffer. +size_t kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n, size_t k, size_t nr, size_t kr); + +/// Runs the RHS packing function for matrix multiplication. +/// +/// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset +/// calculated using the following functions: +/// +/// * RHS: @ref kai_get_rhs_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon. +/// * Bias: @ref kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon. +/// * Output: @ref kai_get_rhs_packed_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon. +/// +/// @param[in] num_groups Number of groups. It must be 1. +/// @param[in] n Number of columns of the output matrix. +/// @param[in] k Common dimension between the LHS and RHS matrix. +/// @param[in] nr Block size in N dimension. +/// @param[in] kr Block size in K dimension. +/// @param[in] sr Number of kr splits. It must be 1. +/// @param[in] rhs_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs RHS matrix data buffer. +/// @param[in] bias Bias matrix data buffer. +/// @param[in] scale Scale data buffer. It must be NULL. +/// @param[out] rhs_packed Packed RHS matrix. +/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. +/// @param[in] params Extra packing parameters. It must be NULL. +void kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/test/common/bfloat16.hpp b/test/common/bfloat16.hpp index ca0e0b371f6bad9f6b00ca55fd6a3e09c331171a..9291918a638499b94545e6b9ff70c30310c46874 100644 --- a/test/common/bfloat16.hpp +++ b/test/common/bfloat16.hpp @@ -82,6 +82,14 @@ public: return _data != rhs._data; } + uint16_t data() const { + return _data; + } + + void set_data(uint16_t data) { + _data = data; + } + /// Writes the value to the output stream. /// /// @param[in] os Output stream to be written to. diff --git a/test/common/compare.cpp b/test/common/compare.cpp index 54af776fd6937a5dd4deaa3469977a7ec1db3a83..b000f3f5efaa7e06295fd4539e7d05b021806008 100644 --- a/test/common/compare.cpp +++ b/test/common/compare.cpp @@ -213,6 +213,9 @@ bool compare( case DataType::FP16: return compare_raw(imp_data, ref_data, format, full_height, full_width, rect, handler); + case DataType::BF16: + return compare_raw(imp_data, ref_data, format, full_height, full_width, rect, handler); + default: break; } diff --git a/test/common/matmul_test_common.cpp b/test/common/matmul_test_common.cpp new file mode 100644 index 0000000000000000000000000000000000000000..73d41c09e28cab56db5c95df6e78e1e0c757c319 --- /dev/null +++ b/test/common/matmul_test_common.cpp @@ -0,0 +1,24 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "matmul_test_common.hpp" + +#include + +namespace kai::test { +void PrintTo(const MatMulTestParams& param, std::ostream* os) { + const auto& [method, shape, portion] = param; + + // NOLINTBEGIN(cppcoreguidelines-pro-bounds-constant-array-index) + *os << "Method_" << method.name // + << "__M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k // + << "__PortionStartRow_" << static_cast(portion.start_row() * 1000) // + << "__PortionStartCol_" << static_cast(portion.start_col() * 1000) // + << "__PortionHeight_" << static_cast(portion.height() * 1000) // + << "__PortionWidth_" << static_cast(portion.width() * 1000); + // NOLINTEND(cppcoreguidelines-pro-bounds-constant-array-index) +} +} // namespace kai::test diff --git a/test/common/matmul_test_common.hpp b/test/common/matmul_test_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..21f6e244c434d6789321c69ddf5b4a2471eec3d2 --- /dev/null +++ b/test/common/matmul_test_common.hpp @@ -0,0 +1,359 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include +#include +#include +#include + +#include "kai/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/float16.hpp" +#include "test/common/matrix_portion.hpp" + +namespace kai::test { +/// Matrix multiplication shape. +struct MatMulShape { + size_t m; ///< LHS height. + size_t n; ///< RHS width. + size_t k; ///< LHS width and RHS height. +}; + +// NOLINTBEGIN(misc-non-private-member-variables-in-classes) + +/// Matrix multiplication method. +struct MatMulMethod { + std::string_view name; ///< Name of matmul method. + + size_t m0{0}; ///< Block size in M dimension. + size_t n0{0}; ///< Block size in N dimension. + size_t k0{0}; ///< Block size in K dimension. + + bool lhs_transposed; ///< LHS matrix is transposed. + bool rhs_transposed; ///< RHS matrix is transposed. + + DataFormat dst_format; ///< Data format of the destination matrix. + DataFormat lhs_format; ///< Data format of the LHS matrix. + DataFormat packed_lhs_format; ///< Data format of the packed LHS matrix. + DataFormat rhs_format; ///< Data format of the RHS matrix. + DataFormat packed_rhs_format; ///< Data format of the packed RHS matrix. + DataFormat bias_format; ///< Data format of the bias vector. + + /// Check if CPU supports required features. + /// + /// @return Supported (true) or not supported (false). + std::function fn_is_supported; + + /// Gets mr value. + /// + /// This is the packing parameter which must be used to pack the LHS matrix (if necessary). + /// + /// @return The mr value. + std::function fn_get_mr; + + /// Gets nr value. + /// + /// This is the packing parameter which must be used to pack the RHS matrix (if necessary). + /// + /// @return The nr value. + std::function fn_get_nr; + + /// Gets kr value. + /// + /// This is the packing parameter which must be used to pack the LHS and RHS matrix (if necessary). + /// + /// @return The kr value. + std::function fn_get_kr; + + /// Gets sr value. + /// + /// This is the packing parameter which must be used to pack the RHS matrix. + /// + /// @return The sr value. + std::function fn_get_sr; + + /// Gets m step value for main kernel. + /// + /// The starting row index must be divisible by `m_step`. + /// + /// @return The m step value. + std::function fn_get_main_m_step; + + /// Gets n step value for RHS packing kernel. + /// + /// The starting row index must be divisible by `n_step`. + /// + /// @return The n step value. + std::function fn_get_pack_rhs_n_step; + + /// Gets n step value for main kernel. + /// + /// The starting column index must be divisible by `n_step`. + /// + /// @return The n step value. + std::function fn_get_main_n_step; + + /// Gets the offset in bytes of the LHS matrix. + /// + /// @param[in] m_idx Coordinate of the matrix in M dimension. + /// @param[in] stride Row stride in bytes. + /// + /// @return The offset in bytes. + std::function fn_get_lhs_offset; + + /// Gets the size in bytes of the packed LHS matrix. + /// + /// @param[in] m Number of rows in the unpacked LHS matrix. + /// @param[in] k Number of columns in the unpacked LHS matrix. + /// @param[in] mr Number of rows to be interleaved. + /// @param[in] kr Unused. Must be 1. + /// @param[in] sr Unused. Must be 1. + /// + /// @return The size in bytes. + std::function fn_get_packed_lhs_size; + + /// Gets the offset in bytes of the packed LHS matrix. + /// + /// @param[in] m_idx Coordinate of the matrix in M dimension. + /// @param[in] k Size of the matrix in K dimension. + /// + /// @return The offset in bytes. + std::function fn_get_packed_lhs_offset; + + /// Preprocesses the LHS matrix. + /// + /// @param[in] m Number of rows of the unpacked LHS matrix. + /// @param[in] k Common dimension between the LHS and RHS matrix. + /// @param[in] mr Block size in M dimension. It must be {{ kernel.interleave_by }}VL. + /// @param[in] kr Block size in K dimension. It must be {{ kernel.block_by }}. + /// @param[in] sr Number of kr splits. It must be 1. + /// @param[in] m_idx_start Unused. Must be 0. + /// @param[in] lhs LHS matrix data buffer. + /// @param[in] lhs_stride Row stride in bytes of the LHS matrix. + /// @param[out] lhs_packed Packed RHS matrix. + std::function + fn_pack_lhs; + + /// Gets a value indicating whether LHS packing is needed. + [[nodiscard]] bool is_pack_lhs_needed() const { + return fn_pack_lhs != nullptr; + } + + /// Gets the offset in bytes of the RHS matrix. + /// + /// @param[in] n_idx Coordinate of the matrix in N dimension. + /// + /// @return The offset in bytes. + std::function fn_get_rhs_offset; + + /// Gets the size in bytes of the packed RHS matrix. + /// + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// + /// @return The size in bytes. + std::function fn_get_packed_rhs_size; + + /// Gets the size in bytes of the packed RHS matrix. + /// + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] nr Block size in N dimension. + /// @param[in] kr Block size in K dimension. + /// + /// @return The size in bytes. + std::function fn_get_packed_rhs_size_generic_block_size = nullptr; + + /// Gets the offset in bytes of the packed RHS matrix in the RHS packing kernel + /// + /// @param[in] n_idx Coordinate of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// + /// @return The offset in bytes. + std::function fn_get_pack_rhs_packed_rhs_offset; + + /// Gets the offset in bytes of the packed RHS matrix in the main kernel. + /// + /// @param[in] n_idx Coordinate of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// + /// @return The offset in bytes. + std::function fn_get_main_packed_rhs_offset; + + std::function + fn_pack_rhs; + + /// Gets the offset in bytes to the data element in the bias buffer. + /// + /// @param[in] n_idx Column index. + /// + /// @return The offset in bytes to the data element. + std::function fn_get_bias_offset; + + /// Gets the offset in bytes to the data element in the destination matrix buffer. + /// + /// @param[in] m_idx Row index. + /// @param[in] n_idx Column index. + /// @param[in] stride Row stride in bytes. + /// + /// @return The offset in bytes to the data element. + std::function fn_get_dst_offset; + + /// Gets the size in bytes of the destination matrix buffer. + /// + /// @param[in] m Number of rows. + /// @param[in] n Number of columns. + /// + /// @return The size in bytes of the destination matrix buffer. + std::function fn_get_dst_size; + + /// Performs F16 or F32 matrix multiplication with RHS packing + /// followed by clamp operation. + /// + /// @param[in] m Size of the matrix in M dimension. + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] lhs LHS data buffer. + /// @param[in] packed_rhs Packed RHS data buffer. + /// @param[out] dst Output data buffer. + /// @param[in] lhs_stride LHS row stride. + /// @param[in] dst_stride_row Output row stride. + /// @param[in] dst_stride_col Output column stride. + /// @param[in] clamp_min Lower bound of the output data. + /// @param[in] clamp_max Upper bound of the output data. + std::function + fn_matmul_f16_f16_f16p = nullptr; + + std::function + fn_matmul_f32_f32_f32p = nullptr; + + /// Performs BF16 matrix multiplication with LHS and RHS packing + /// followed by clamp operation. + /// + /// @param[in] m Size of the matrix in M dimension. + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] packed_lhs Packed LHS data buffer. + /// @param[in] packed_rhs Packed RHS data buffer. + /// @param[out] dst Output data buffer. + /// @param[in] dst_stride_row Output row stride. + /// @param[in] dst_stride_col Output column stride. + /// @param[in] clamp_min Lower bound of the output data. + /// @param[in] clamp_max Upper bound of the output data. + std::function + fn_matmul_f32_bf16p_bf16p = nullptr; + + /// Performs F32 matrix multiplication with LHS & RHS packing + /// followed by clamp operation. + /// + /// @param[in] m Number of output rows to be computed. + /// @param[in] n Number of output columns to be computed. + /// @param[in] k Common dimension of the LHS and RHS operands. + /// @param[in] packed_lhs Packed LHS matrix buffer. + /// @param[in] packed_rhs Packed RHS matrix buffer. + /// @param[out] dst Output matrix buffer. + /// @param[in] dst_stride_row Row stride in bytes of the output matrix. + /// @param[in] dst_stride_col Column stride in bytes of the output matrix. + /// @param[in] clamp_min Minimum value to clamp the final result. + /// @param[in] clamp_max Maximum value to clamp the final result. + std::function + fn_matmul_f32_f32p_f32p = nullptr; + + /// Gets a value indicating whether pre-processing the RHS matrix is needed. + [[nodiscard]] bool is_pack_rhs_needed() const { + return fn_pack_rhs != nullptr; + } + + /// Preprocesses the RHS matrix. + /// + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] rhs RHS data buffer. + /// @param[in] rhs_row_stride RHS row stride. + /// @param[in] bias Bias data buffer. + /// @param[in] scale Quantization scales data buffer. + /// @param[out] packed_rhs Packed RHS data buffer. + void pack_rhs( + size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale, + void* packed_rhs) const { + KAI_UNUSED(n); + KAI_UNUSED(k); + KAI_UNUSED(rhs); + KAI_UNUSED(rhs_row_stride); + KAI_UNUSED(bias); + KAI_UNUSED(scale); + + if (fn_pack_rhs != nullptr) { + fn_pack_rhs( + 1, n, k, fn_get_nr(), fn_get_kr(), fn_get_sr(), rhs_row_stride, rhs, bias, nullptr, packed_rhs, 0, + nullptr); + } else { + KAI_ERROR("RHS pre-processing is not supported!"); + } + } + + [[nodiscard]] bool has_main_kernel() const { + return fn_matmul_f16_f16_f16p != nullptr || fn_matmul_f32_f32p_f32p != nullptr || + fn_matmul_f32_f32_f32p != nullptr || fn_matmul_f32_bf16p_bf16p != nullptr; + } + + void main_kernel( + size_t m, size_t n, size_t k, const void* lhs, const void* rhs, const void* bias, void* dst, size_t lhs_stride, + size_t rhs_stride, size_t dst_stride, float clamp_min, float clamp_max) const { + KAI_UNUSED(bias); + KAI_UNUSED(rhs_stride); + + if (fn_matmul_f16_f16_f16p) { + fn_matmul_f16_f16_f16p( + m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(Float16), clamp_min, + static_cast(clamp_max)); + } else if (fn_matmul_f32_f32_f32p) { + fn_matmul_f32_f32_f32p( + m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(float), clamp_min, + static_cast(clamp_max)); + } else if (fn_matmul_f32_f32p_f32p) { + fn_matmul_f32_f32p_f32p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); + } else if (fn_matmul_f32_bf16p_bf16p) { + fn_matmul_f32_bf16p_bf16p( + m, n, k, reinterpret_cast(lhs), rhs, reinterpret_cast(dst), dst_stride, + sizeof(float), clamp_min, clamp_max); + } else { + KAI_ERROR("Main kernel is not available!"); + } + } +}; + +// NOLINTEND(misc-non-private-member-variables-in-classes) + +/// Matrix multiplication test information. +using MatMulTestParams = std::tuple; + +/// Prints the test information. +void PrintTo(const MatMulTestParams& param, std::ostream* os); +} // namespace kai::test diff --git a/test/common/memory.hpp b/test/common/memory.hpp index bf5fbb017ed3d4cca5895a895c795c08ff30721f..c856218f6351eefa6c30fdf715ec6f875f756037 100644 --- a/test/common/memory.hpp +++ b/test/common/memory.hpp @@ -7,8 +7,11 @@ #pragma once #include +#include #include +#include "kai/kai_common.h" +#include "test/common/bfloat16.hpp" #include "test/common/int4.hpp" namespace kai::test { @@ -39,6 +42,9 @@ T read_array(const void* array, size_t index) { } else if constexpr (std::is_same_v) { const auto [lo, hi] = Int4::unpack_u8(reinterpret_cast(array)[index / 2]); return index % 2 == 0 ? lo : hi; + } else if constexpr (std::is_same_v) { + uint16_t raw_value = reinterpret_cast(array)[index]; + return BFloat16(kai_cast_f32_bf16(raw_value)); } else { return reinterpret_cast(array)[index]; } diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp index 221ba36079ec8d6c6bf59289826a37ec2525b300..ad123762d0852f007c943674e571a6b5a3414a7e 100644 --- a/test/reference/pack.cpp +++ b/test/reference/pack.cpp @@ -10,29 +10,31 @@ #include #include #include -#include #include #include "kai/kai_common.h" +#include "test/common/bfloat16.hpp" #include "test/common/data_format.hpp" #include "test/common/data_type.hpp" -#include "test/common/float16.hpp" #include "test/common/memory.hpp" #include "test/common/round.hpp" -#include "test/reference/quantize.hpp" namespace kai::test { namespace { +uint16_t convert(const uint8_t* src_ptr_elm, DataType src_dtype, DataType dst_dtype) { + KAI_ASSUME(src_dtype == DataType::FP32 && dst_dtype == DataType::BF16); + return BFloat16(*reinterpret_cast(src_ptr_elm)).data(); +} + std::vector pack_block( - const void* src, size_t data_esize, size_t full_height, size_t full_width, size_t block_height, size_t block_width, - size_t subblock_height, size_t subblock_width) { + const void* src, DataType src_dtype, DataType dst_dtype, size_t src_esize, size_t dst_esize, size_t full_height, + size_t full_width, size_t block_height, size_t block_width, size_t subblock_height, size_t subblock_width) { const auto dst_bytes = - round_up_multiple(full_height, block_height) * round_up_multiple(full_width, block_width) * data_esize; + round_up_multiple(full_height, block_height) * round_up_multiple(full_width, block_width) * dst_esize; - std::vector dst; - dst.resize(dst_bytes); + std::vector dst(dst_bytes, 0); const auto* src_ptr = reinterpret_cast(src); auto* dst_ptr = dst.data(); @@ -42,18 +44,38 @@ std::vector pack_block( for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) { for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) { for (size_t y_element = 0; y_element < subblock_height; ++y_element) { - if (y_block + y_subblock + y_element < full_height) { - const auto len = std::min(subblock_width, full_width - x_block - x_subblock); - - memcpy( - dst_ptr, - src_ptr + - ((y_block + y_subblock + y_element) * full_width + x_block + x_subblock) * - data_esize, - len * data_esize); - } + if (src_dtype == dst_dtype) { + const size_t esize = dst_esize; + + if (y_block + y_subblock + y_element < full_height) { + const auto len = std::min(subblock_width, full_width - x_block - x_subblock); + + memcpy( + dst_ptr, + src_ptr + + ((y_block + y_subblock + y_element) * full_width + x_block + x_subblock) * + esize, + len * esize); + } - dst_ptr += subblock_width * data_esize; + dst_ptr += subblock_width * esize; + } else if (dst_esize == 2 /* 16 bits */) { + for (size_t x_element = 0; x_element < subblock_width; ++x_element) { + if (y_block + y_subblock + y_element < full_height) { + if (x_block + x_subblock + x_element < full_width) { + const uint8_t* src_ptr_elm = src_ptr + + ((y_block + y_subblock + y_element) * full_width + x_block + x_subblock + + x_element) * + src_esize; + + uint16_t src_value = convert(src_ptr_elm, src_dtype, dst_dtype); + memcpy(dst_ptr, &src_value, dst_esize); + } + } + + dst_ptr += dst_esize; + } + } } } } @@ -67,43 +89,65 @@ std::vector pack_block( /// Packs the matrix from raw to per-row bias format. std::vector pack_bias_per_row( - size_t data_esize, size_t zero_point_esize, const void* src, const void* bias, size_t height, size_t width, - size_t block_height, size_t block_width, size_t subblock_height, size_t subblock_width) { + DataType src_dtype, DataType bias_dtype, DataType dst_dtype, size_t src_esize, size_t bias_esize, size_t dst_esize, + const void* src, const void* bias, size_t height, size_t width, size_t block_height, size_t block_width, + size_t subblock_height, size_t subblock_width) { + KAI_ASSUME(src_dtype == bias_dtype); + const auto num_groups = (height + block_height - 1) / block_height; const auto group_num_blocks = (width + block_width - 1) / block_width; - - const auto group_zero_points_bytes = block_height * zero_point_esize; - const auto block_data_bytes = block_height * block_width * data_esize; - const auto group_bytes = group_zero_points_bytes + group_num_blocks * block_data_bytes; + const auto group_bias_bytes = block_height * bias_esize; + const auto block_data_bytes = block_height * block_width * dst_esize; + const auto group_bytes = group_bias_bytes + group_num_blocks * block_data_bytes; const auto dst_bytes = num_groups * group_bytes; - std::vector dst; - dst.resize(dst_bytes); + std::vector dst(dst_bytes, 0); const auto* src_ptr = reinterpret_cast(src); const auto* bias_ptr = reinterpret_cast(bias); auto* dst_ptr = dst.data(); for (size_t y_block = 0; y_block < height; y_block += block_height) { - // Packs the zero points. + // Packs the bias. const auto bias_len = std::min(block_height, height - y_block); - memcpy(dst_ptr, bias_ptr, bias_len * zero_point_esize); - bias_ptr += block_height * zero_point_esize; - dst_ptr += block_height * zero_point_esize; + memcpy(dst_ptr, bias_ptr, bias_len * bias_esize); + bias_ptr += block_height * bias_esize; + dst_ptr += block_height * bias_esize; for (size_t x_block = 0; x_block < width; x_block += block_width) { for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) { for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) { for (size_t y_element = 0; y_element < subblock_height; ++y_element) { - if (y_block + y_subblock + y_element < height) { - const auto len = std::min(subblock_width, width - x_block - x_subblock); - memcpy( - dst_ptr, - src_ptr + - ((y_block + y_subblock + y_element) * width + x_block + x_subblock) * data_esize, - len * data_esize); + if (src_dtype == dst_dtype) { + const size_t esize = dst_esize; + if (y_block + y_subblock + y_element < height) { + const auto len = std::min(subblock_width, width - x_block - x_subblock); + + memcpy( + dst_ptr, + src_ptr + + ((y_block + y_subblock + y_element) * width + x_block + x_subblock) * esize, + len * esize); + } + + dst_ptr += subblock_width * esize; + } else if (dst_esize == 2 /* 16 bits */) { + for (size_t x_element = 0; x_element < subblock_width; ++x_element) { + if (y_block + y_subblock + y_element < height) { + if (x_block + x_subblock + x_element < width) { + const uint8_t* src_ptr_elm = src_ptr + + ((y_block + y_subblock + y_element) * width + x_block + x_subblock + + x_element) * + src_esize; + + const uint16_t dst_value = convert(src_ptr_elm, src_dtype, dst_dtype); + memcpy(dst_ptr, &dst_value, dst_esize); + } + } + + dst_ptr += dst_esize; + } } - dst_ptr += subblock_width * data_esize; } } } @@ -118,7 +162,7 @@ std::vector pack_bias_per_row( } // namespace std::vector pack( - const DataFormat& dst_format, const void* src, [[maybe_unused]] const void* scales, const void* zero_points, + const DataFormat& dst_format, const void* src, [[maybe_unused]] const void* scales, const void* bias, const DataFormat& src_format, size_t height, size_t width) { const auto dst_dt = dst_format.data_type(); const auto dst_qf = dst_format.pack_format(); @@ -131,27 +175,31 @@ std::vector pack( const auto subblock_width = dst_format.actual_subblock_width(width); if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::BIAS_PER_ROW) { - KAI_ASSUME(src_dt == dst_dt); + KAI_ASSUME((src_dt == dst_dt) || (src_dt == DataType::FP32 && dst_dt == DataType::BF16)); - const auto data_esize = data_type_size_in_bits(dst_dt); - const auto zero_point_esize = data_type_size_in_bits(dst_format.zero_point_data_type()); + const auto src_esize = data_type_size_in_bits(src_dt); + const auto dst_esize = data_type_size_in_bits(dst_dt); + const auto bias_esize = data_type_size_in_bits(dst_format.zero_point_data_type()); + const auto bias_dt = dst_format.zero_point_data_type(); - if (data_esize % 8 == 0 && zero_point_esize % 8 == 0) { - return pack_bias_per_row( - data_esize / 8, zero_point_esize / 8, src, zero_points, height, width, block_height, block_width, - subblock_height, subblock_width); - } + KAI_ASSUME(dst_esize % 8 == 0 && bias_esize % 8 == 0 && src_esize % 8 == 0); + + return pack_bias_per_row( + src_dt, bias_dt, dst_dt, src_esize / 8, bias_esize / 8, dst_esize / 8, src, bias, height, width, + block_height, block_width, subblock_height, subblock_width); } if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::NONE) { - KAI_ASSUME(src_dt == dst_dt); + KAI_ASSUME((src_dt == dst_dt) || (src_dt == DataType::FP32 && dst_dt == DataType::BF16)); - const auto data_esize = data_type_size_in_bits(dst_dt); + const auto dst_esize = data_type_size_in_bits(dst_dt); + const auto src_esize = data_type_size_in_bits(src_dt); - if (data_esize % 8 == 0) { - return pack_block( - src, data_esize / 8, height, width, block_height, block_width, subblock_height, subblock_width); - } + KAI_ASSUME(src_esize % 8 == 0 && dst_esize % 8 == 0); + + return pack_block( + src, src_dt, dst_dt, src_esize / 8, dst_esize / 8, height, width, block_height, block_width, + subblock_height, subblock_width); } KAI_ERROR("Unsupported operation!"); diff --git a/test/reference/pack.hpp b/test/reference/pack.hpp index 10d76a7f6a289fea0a5217c391cea77b703cedaa..128ad0400b09a4a3ae8e1d503fc8ae0824837ebe 100644 --- a/test/reference/pack.hpp +++ b/test/reference/pack.hpp @@ -22,8 +22,8 @@ class DataFormat; /// @param[in] height Number of rows of the source matrix. /// @param[in] width Number of columns of the source matrix. std::vector pack( - const DataFormat& dst_format, const void* src, const void* scales, const void* zero_points, - const DataFormat& src_format, size_t height, size_t width); + const DataFormat& dst_format, const void* src, const void* scales, const void* bias, const DataFormat& src_format, + size_t height, size_t width); /// Packs the quantized data and the quantization scale into a single buffer. /// diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..730ff5aed286f79ab87bac3014fa9b150000f4c6 --- /dev/null +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -0,0 +1,374 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "kai/kai_common.h" +#include "test/common/compare.hpp" +#include "test/common/cpu_info.hpp" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/common/matmul_test_common.hpp" +#include "test/common/matrix_portion.hpp" +#include "test/common/printer.hpp" +#include "test/reference/fill.hpp" +#include "test/reference/matmul.hpp" +#include "test/reference/pack.hpp" + +// matmul_clamp_f32_bf16p_bf16p +#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h" +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h" +#include "kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h" +namespace kai::test { + +/// List of supported matrix multiplication methods. +namespace { +const std::array matmul_methods = { + MatMulMethod{ + .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla", + + .m0 = 8, + .n0 = 12, + .k0 = 4, + + .lhs_transposed = false, + .rhs_transposed = false, + + .dst_format = DataFormat(DataType::FP32), + .lhs_format = DataFormat(DataType::FP32), + .packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4), + .rhs_format = DataFormat(DataType::FP32), + .packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4), + .bias_format = DataFormat(DataType::FP32), + .fn_is_supported = cpu_has_bf16, + + .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + + .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon, + .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon, + .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p_f32_neon, + + .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, + .fn_get_packed_rhs_size = nullptr, + .fn_get_packed_rhs_size_generic_block_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, + .fn_get_pack_rhs_packed_rhs_offset = nullptr, + .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, + + .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, + + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + + .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + }, + MatMulMethod{ + .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_opt_bias", + + .m0 = 8, + .n0 = 12, + .k0 = 4, + + .lhs_transposed = false, + .rhs_transposed = false, + + .dst_format = DataFormat(DataType::FP32), + .lhs_format = DataFormat(DataType::FP32), + .packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4), + .rhs_format = DataFormat(DataType::FP32), + .packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4), + .bias_format = DataFormat(DataType::UNKNOWN), + .fn_is_supported = cpu_has_bf16, + + .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + + .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon, + .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon, + .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p_f32_neon, + + .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, + .fn_get_packed_rhs_size = nullptr, + .fn_get_packed_rhs_size_generic_block_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, + .fn_get_pack_rhs_packed_rhs_offset = nullptr, + .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, + + .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, + + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + + .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + }}; +} // namespace + +/// Matrix multiplication test fixture. +class MatMulTestBf16 : public testing::TestWithParam { +private: + /// Unique ID: m, n, k + using TestDataId = std::tuple; + +protected: + /// Cached test data that is shared between multiple test case. + struct TestData { + std::vector lhs{}; ///< LHS operand. + std::vector ref_packed_lhs{}; ///< Reference packed LHS. + std::vector rhs{}; ///< RHS operand. + std::vector rhs_scales{}; ///< RHS per-row quantization scales. + std::vector bias{}; ///< Bias. + std::vector ref_packed_rhs{}; ///< Reference packed RHS. + std::vector ref_dst{}; ///< Reference output. + }; + + /// Gets the test data for the current test case. + static const TestData& test_data() { + const auto& [method, info, portion] = GetParam(); + const TestDataId data_id{info.m, info.n, info.k, method.name}; + + // If the test data is already available, returns it. + const auto data_it = _data.find(data_id); + + if (data_it != _data.end()) { + return data_it->second; + } + + // Generates the test data. + const auto has_lhs_pack = method.packed_lhs_format.data_type() != DataType::UNKNOWN; + const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN; + const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN; + + const auto lhs_h = method.lhs_transposed ? info.k : info.m; + const auto lhs_w = method.lhs_transposed ? info.m : info.k; + auto lhs = fill_matrix_random(lhs_h, lhs_w, method.lhs_format, 0); + std::vector ref_packed_lhs; + + if (has_lhs_pack) { + ref_packed_lhs = + pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w); + } + + const auto rhs_h = method.rhs_transposed ? info.n : info.k; + const auto rhs_w = method.rhs_transposed ? info.k : info.n; + auto rhs = fill_matrix_random(rhs_h, rhs_w, method.rhs_format, 1); + + std::vector rhs_scales; + if (data_type_is_quantized(method.rhs_format.data_type()) && + method.rhs_format.pack_format() == DataFormat::PackFormat::NONE) { + rhs_scales = fill_matrix_random(rhs_h, 1, DataFormat(DataType::FP32), 2); + } + + const auto bias_h = 1; + const auto bias_w = info.n; + std::vector bias; + + if (has_bias) { + bias = fill_matrix_random(bias_h, bias_w, method.bias_format, 3); + } + + constexpr size_t nr = 12; + constexpr size_t kr = 4; + + std::vector packed_rhs; + packed_rhs.resize(method.fn_get_packed_rhs_size_generic_block_size(rhs_w, rhs_h, nr, kr)); + + if (has_rhs_pack) { + const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(rhs_w); + method.pack_rhs( + info.n, info.k, rhs.data(), ref_rhs_row_stride, has_bias ? bias.data() : nullptr, nullptr, + packed_rhs.data()); + } + + KAI_ASSUME(method.lhs_format.is_raw()); + KAI_ASSUME(method.rhs_format.is_raw()); + KAI_ASSUME(method.dst_format.is_raw()); + + auto ref_dst = matmul( + lhs.data(), nullptr, nullptr, method.lhs_format.data_type(), // + rhs.data(), rhs_scales.data(), nullptr, method.rhs_format.data_type(), // + has_bias ? bias.data() : nullptr, nullptr, nullptr, method.bias_format.data_type(), // + method.dst_format.data_type(), // + info.m, info.n, info.k, method.lhs_transposed, method.rhs_transposed); + + const auto& data = _data[data_id] = { + .lhs = std::move(lhs), + .ref_packed_lhs = std::move(ref_packed_lhs), + .rhs = std::move(rhs), + .rhs_scales = std::move(rhs_scales), + .bias = std::move(bias), + .ref_packed_rhs = std::move(packed_rhs), + .ref_dst = std::move(ref_dst), + }; + + return data; + } + +private: + // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) + static std::map _data; + // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) +}; + +// NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) +std::map MatMulTestBf16::_data; +// NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) + +/// Tests the output. +TEST_P(MatMulTestBf16, Output) { + const auto& [method, info, portion] = GetParam(); + const auto& data = test_data(); + + if (method.fn_is_supported && !method.fn_is_supported()) { + GTEST_SKIP(); + } + + if (!method.has_main_kernel()) { + GTEST_SKIP(); + } + + const auto m_step = method.fn_get_main_m_step(); + ASSERT_EQ(m_step, method.m0); + + const auto n_step = method.fn_get_main_n_step(); + ASSERT_EQ(n_step, method.n0); + + const auto rect = portion.compute_portion(info.m, info.n, method.m0, method.n0); + + if (rect.height() == 0 || rect.width() == 0) { + GTEST_SKIP(); + } + + const size_t lhs_w = info.k; + const size_t rhs_w = rect.width(); + const size_t bias_w = info.n; + const size_t dst_w = info.n; + const bool has_bias = (data.bias.size() > 0); + + const auto lhs_start_row = method.lhs_transposed ? 0 : rect.start_row(); + const auto lhs_stride = method.lhs_format.default_row_stride(lhs_w); + + std::vector lhs_data; + const size_t lhs_packed_size = method.fn_get_packed_lhs_size(info.m, info.k, method.m0, method.k0, 1 /* sr */); + lhs_data.resize(lhs_packed_size); + + uintptr_t lhs_offset = method.fn_get_lhs_offset(lhs_start_row, lhs_stride); + uintptr_t lhs_packed_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k); + + KAI_UNUSED(lhs_offset); + method.fn_pack_lhs( + rect.height(), info.k, method.m0, method.k0, 1 /* sr */, 0 /* m_idx_start */, data.lhs.data() + lhs_offset, + lhs_stride, lhs_data.data() + lhs_packed_offset); + + const auto rhs_stride = method.rhs_format.default_row_stride(info.n); + + std::vector rhs_data; + const size_t rhs_packed_size = + method.fn_get_packed_rhs_size_generic_block_size(info.n, info.k, method.n0, method.k0); + rhs_data.resize(rhs_packed_size); + + const auto packed_rhs_start_row = rect.start_col(); + const auto packed_rhs_start_col = 0; + + uintptr_t rhs_offset = method.fn_get_rhs_offset(rect.start_col()); + uintptr_t rhs_packed_offset = method.fn_get_main_packed_rhs_offset(packed_rhs_start_row, info.k); + const auto ref_rhs_packed_offset = + method.packed_rhs_format.default_offset_in_bytes(packed_rhs_start_row, packed_rhs_start_col, info.k); + + ASSERT_EQ(rhs_packed_offset, ref_rhs_packed_offset); + + uintptr_t bias_offset = sizeof(float) * rect.start_col(); + + method.fn_pack_rhs( + 1, // num_groups + rhs_w, info.k, method.n0, method.k0, + 1, // sr + rhs_stride, data.rhs.data() + rhs_offset, has_bias ? data.bias.data() + bias_offset : nullptr, + NULL, // Scale + rhs_data.data() + rhs_packed_offset, 0, NULL); + + if (has_bias) { + const auto ref_bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_col(), bias_w); + ASSERT_EQ(ref_bias_offset, bias_offset); + } + + const auto dst_stride = method.dst_format.default_row_stride(dst_w); + const auto dst_offset = method.fn_get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); + const auto ref_dst_offset = method.dst_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), dst_w); + ASSERT_EQ(dst_offset, ref_dst_offset); + + const auto dst_size = method.fn_get_dst_size(info.m, info.n); + const auto ref_dst_size = method.dst_format.default_size_in_bytes(info.m, info.n); + ASSERT_EQ(dst_size, ref_dst_size); + + std::vector dst; + dst.resize(dst_size); + method.main_kernel( + rect.height(), rect.width(), info.k, lhs_data.data() + lhs_packed_offset, rhs_data.data() + rhs_packed_offset, + NULL, dst.data() + dst_offset, lhs_stride, rhs_stride, dst_stride, -std::numeric_limits::infinity(), + std::numeric_limits::infinity()); + + DefaultMismatchHandler handler(0, 0.02, 0, 0.05); + const auto success = compare(dst.data(), data.ref_dst.data(), method.dst_format, info.m, info.n, rect, handler); + + ASSERT_TRUE(success); +} + +INSTANTIATE_TEST_SUITE_P( + MatMul, MatMulTestBf16, + testing::Combine( + testing::ValuesIn(matmul_methods), + testing::Values( + MatMulShape{3, 7, 3}, // Smaller than block size + MatMulShape{12, 8, 4}, // Same block size + MatMulShape{1, 1, 1023}, // Long K + MatMulShape{1013, 1, 5}, // Long M + MatMulShape{2, 1013, 6}, // Long N + MatMulShape{13, 33, 23}, // + MatMulShape{93, 57, 89}, // + MatMulShape{256, 256, 256}, // Nice shapes + MatMulShape{257, 113, 373} // Prime numbers + ), + testing::Values( + MatrixPortion(0, 0, 1, 1), // Full matrix. + MatrixPortion(0, 0, 0.25, 0.25), // Top-left corner. + MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. + MatrixPortion(0.75, 0, 1, 1), // Partial rows + MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle + )), + testing::PrintToStringParamName()); +} // namespace kai::test diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index f2d5e54423ae8cffeecc9fb2d31bb5647fc463f8..752a370147de64238ee2a4705257578f1d584323 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -25,7 +25,7 @@ #include "test/common/cpu_info.hpp" #include "test/common/data_format.hpp" #include "test/common/data_type.hpp" -#include "test/common/float16.hpp" +#include "test/common/matmul_test_common.hpp" #include "test/common/matrix_portion.hpp" #include "test/common/printer.hpp" #include "test/common/sme.hpp" @@ -46,296 +46,6 @@ #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h" namespace kai::test { -// NOLINTBEGIN(misc-non-private-member-variables-in-classes) - -/// Matrix multiplication method. -struct MatMulMethod { - std::string_view name; ///< Name of matmul method. - - size_t m0; ///< Block size in M dimension. - size_t n0; ///< Block size in N dimension. - - bool lhs_transposed; ///< LHS matrix is transposed. - bool rhs_transposed; ///< RHS matrix is transposed. - - DataFormat dst_format; ///< Data format of the destination matrix. - DataFormat lhs_format; ///< Data format of the LHS matrix. - DataFormat packed_lhs_format; ///< Data format of the packed LHS matrix. - DataFormat rhs_format; ///< Data format of the RHS matrix. - DataFormat packed_rhs_format; ///< Data format of the packed RHS matrix. - DataFormat bias_format; ///< Data format of the bias vector. - - /// Check if CPU supports required features. - /// - /// @return Supported (true) or not supported (false). - std::function fn_is_supported; - - /// Gets mr value. - /// - /// This is the packing parameter which must be used to pack the LHS matrix (if necessary). - /// - /// @return The mr value. - std::function fn_get_mr; - - /// Gets nr value. - /// - /// This is the packing parameter which must be used to pack the RHS matrix (if necessary). - /// - /// @return The nr value. - std::function fn_get_nr; - - /// Gets kr value. - /// - /// This is the packing parameter which must be used to pack the LHS and RHS matrix (if necessary). - /// - /// @return The kr value. - std::function fn_get_kr; - - /// Gets sr value. - /// - /// This is the packing parameter which must be used to pack the RHS matrix. - /// - /// @return The sr value. - std::function fn_get_sr; - - /// Gets m step value for main kernel. - /// - /// The starting row index must be divisible by `m_step`. - /// - /// @return The m step value. - std::function fn_get_main_m_step; - - /// Gets n step value for RHS packing kernel. - /// - /// The starting row index must be divisible by `n_step`. - /// - /// @return The n step value. - std::function fn_get_pack_rhs_n_step; - - /// Gets n step value for main kernel. - /// - /// The starting column index must be divisible by `n_step`. - /// - /// @return The n step value. - std::function fn_get_main_n_step; - - /// Gets the offset in bytes of the LHS matrix. - /// - /// @param[in] m_idx Coordinate of the matrix in M dimension. - /// @param[in] stride Row stride in bytes. - /// - /// @return The offset in bytes. - std::function fn_get_lhs_offset; - - /// Gets the size in bytes of the packed LHS matrix. - /// - /// @param[in] m Number of rows in the unpacked LHS matrix. - /// @param[in] k Number of columns in the unpacked LHS matrix. - /// @param[in] mr Number of rows to be interleaved. - /// @param[in] kr Unused. Must be 1. - /// @param[in] sr Unused. Must be 1. - /// - /// @return The size in bytes. - std::function fn_get_packed_lhs_size; - - /// Gets the offset in bytes of the packed LHS matrix. - /// - /// @param[in] m_idx Coordinate of the matrix in M dimension. - /// @param[in] k Size of the matrix in K dimension. - /// - /// @return The offset in bytes. - std::function fn_get_packed_lhs_offset; - - /// Preprocesses the LHS matrix. - /// - /// @param[in] m Number of rows of the unpacked LHS matrix. - /// @param[in] k Common dimension between the LHS and RHS matrix. - /// @param[in] mr Block size in M dimension. It must be {{ kernel.interleave_by }}VL. - /// @param[in] kr Block size in K dimension. It must be {{ kernel.block_by }}. - /// @param[in] sr Number of kr splits. It must be 1. - /// @param[in] m_idx_start Unused. Must be 0. - /// @param[in] lhs LHS matrix data buffer. - /// @param[in] lhs_stride Row stride in bytes of the LHS matrix. - /// @param[out] lhs_packed Packed RHS matrix. - std::function - fn_pack_lhs; - - /// Gets a value indicating whether LHS packing is needed. - [[nodiscard]] bool is_pack_lhs_needed() const { - return fn_pack_lhs != nullptr; - } - - /// Gets the offset in bytes of the RHS matrix. - /// - /// @param[in] n_idx Coordinate of the matrix in N dimension. - /// - /// @return The offset in bytes. - std::function fn_get_rhs_offset; - - /// Gets the size in bytes of the packed RHS matrix. - /// - /// @param[in] n Size of the matrix in N dimension. - /// @param[in] k Size of the matrix in K dimension. - /// - /// @return The size in bytes. - std::function fn_get_packed_rhs_size; - - /// Gets the offset in bytes of the packed RHS matrix in the RHS packing kernel - /// - /// @param[in] n_idx Coordinate of the matrix in N dimension. - /// @param[in] k Size of the matrix in K dimension. - /// - /// @return The offset in bytes. - std::function fn_get_pack_rhs_packed_rhs_offset; - - /// Gets the offset in bytes of the packed RHS matrix in the main kernel. - /// - /// @param[in] n_idx Coordinate of the matrix in N dimension. - /// @param[in] k Size of the matrix in K dimension. - /// - /// @return The offset in bytes. - std::function fn_get_main_packed_rhs_offset; - - std::function - fn_pack_rhs; - - /// Gets the offset in bytes to the data element in the bias buffer. - /// - /// @param[in] n_idx Column index. - /// - /// @return The offset in bytes to the data element. - std::function fn_get_bias_offset; - - /// Gets the offset in bytes to the data element in the destination matrix buffer. - /// - /// @param[in] m_idx Row index. - /// @param[in] n_idx Column index. - /// @param[in] stride Row stride in bytes. - /// - /// @return The offset in bytes to the data element. - std::function fn_get_dst_offset; - - /// Gets the size in bytes of the destination matrix buffer. - /// - /// @param[in] m Number of rows. - /// @param[in] n Number of columns. - /// - /// @return The size in bytes of the destination matrix buffer. - std::function fn_get_dst_size; - - /// Performs F16 or F32 matrix multiplication with RHS packing - /// followed by clamp operation. - /// - /// @param[in] m Size of the matrix in M dimension. - /// @param[in] n Size of the matrix in N dimension. - /// @param[in] k Size of the matrix in K dimension. - /// @param[in] lhs LHS data buffer. - /// @param[in] packed_rhs Packed RHS data buffer. - /// @param[out] dst Output data buffer. - /// @param[in] lhs_stride LHS row stride. - /// @param[in] dst_stride Output row stride. - /// @param[in] clamp_min Lower bound of the output data. - /// @param[in] clamp_max Upper bound of the output data. - std::function - fn_matmul_f16_f16_f16p; - - std::function - fn_matmul_f32_f32_f32p; - - /// Performs F32 matrix multiplication with LHS & RHS packing - /// followed by clamp operation. - /// - /// @param[in] m Number of output rows to be computed. - /// @param[in] n Number of output columns to be computed. - /// @param[in] k Common dimension of the LHS and RHS operands. - /// @param[in] packed_lhs Packed LHS matrix buffer. - /// @param[in] packed_rhs Packed RHS matrix buffer. - /// @param[out] dst Output matrix buffer. - /// @param[in] dst_stride_row Row stride in bytes of the output matrix. - /// @param[in] dst_stride_col Column stride in bytes of the output matrix. - /// @param[in] clamp_min Minimum value to clamp the final result. - /// @param[in] clamp_max Maximum value to clamp the final result. - std::function - fn_matmul_f32_f32p_f32p; - - /// Gets a value indicating whether pre-processing the RHS matrix is needed. - [[nodiscard]] bool is_pack_rhs_needed() const { - return fn_pack_rhs != nullptr; - } - - /// Preprocesses the RHS matrix. - /// - /// @param[in] n Size of the matrix in N dimension. - /// @param[in] k Size of the matrix in K dimension. - /// @param[in] rhs RHS data buffer. - /// @param[in] rhs_row_stride RHS row stride. - /// @param[in] bias Bias data buffer. - /// @param[in] scale Quantization scales data buffer. - /// @param[out] packed_rhs Packed RHS data buffer. - void pack_rhs( - size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale, - void* packed_rhs) const { - KAI_UNUSED(n); - KAI_UNUSED(k); - KAI_UNUSED(rhs); - KAI_UNUSED(rhs_row_stride); - KAI_UNUSED(bias); - KAI_UNUSED(scale); - KAI_UNUSED(packed_rhs); - - if (fn_pack_rhs != nullptr) { - fn_pack_rhs( - 1, n, k, fn_get_nr(), fn_get_kr(), fn_get_sr(), rhs_row_stride, rhs, bias, nullptr, packed_rhs, 0, - nullptr); - } else { - KAI_ERROR("RHS pre-processing is not supported!"); - } - } - - [[nodiscard]] bool has_main_kernel() const { - return fn_matmul_f16_f16_f16p != nullptr || fn_matmul_f32_f32p_f32p != nullptr || - fn_matmul_f32_f32_f32p != nullptr; - } - - void main_kernel( - size_t m, size_t n, size_t k, const void* lhs, const void* rhs, const void* bias, void* dst, size_t lhs_stride, - size_t rhs_stride, size_t dst_stride, float clamp_min, float clamp_max) const { - KAI_UNUSED(bias); - KAI_UNUSED(rhs_stride); - if (fn_matmul_f16_f16_f16p) { - fn_matmul_f16_f16_f16p( - m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(Float16), clamp_min, - static_cast(clamp_max)); - } else if (fn_matmul_f32_f32_f32p) { - fn_matmul_f32_f32_f32p( - m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(float), clamp_min, - static_cast(clamp_max)); - } else if (fn_matmul_f32_f32p_f32p) { - fn_matmul_f32_f32p_f32p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); - } else { - KAI_ERROR("Main kernel is not available!"); - } - } -}; - -// NOLINTEND(misc-non-private-member-variables-in-classes) - /// List of supported matrix multiplication methods. static const std::array matmul_methods = { MatMulMethod{ @@ -486,35 +196,11 @@ static const std::array matmul_methods = { }, }; -/// Matrix multiplication shape. -struct MatMulShape { - size_t m; ///< LHS height. - size_t n; ///< RHS width. - size_t k; ///< LHS width and RHS height. -}; - -/// Matrix multiplication test information. -using MatMulTestParams = std::tuple; - -/// Prints the test information. -void PrintTo(const MatMulTestParams& param, std::ostream* os) { - const auto& [method_no, shape, portion] = param; - - // NOLINTBEGIN(cppcoreguidelines-pro-bounds-constant-array-index) - *os << "Method_" << matmul_methods[method_no].name // - << "__M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k // - << "__PortionStartRow_" << static_cast(portion.start_row() * 1000) // - << "__PortionStartCol_" << static_cast(portion.start_col() * 1000) // - << "__PortionHeight_" << static_cast(portion.height() * 1000) // - << "__PortionWidth_" << static_cast(portion.width() * 1000); - // NOLINTEND(cppcoreguidelines-pro-bounds-constant-array-index) -} - /// Matrix multiplication test fixture. class MatMulTest : public testing::TestWithParam { private: /// Unique ID: m, n, k, method_id. - using TestDataId = std::tuple; + using TestDataId = std::tuple; protected: /// Cached test data that is shared between multiple test case. @@ -530,8 +216,8 @@ protected: /// Gets the test data for the current test case. static const TestData& test_data() { - const auto& [method_no, info, portion] = GetParam(); - const TestDataId data_id{info.m, info.n, info.k, method_no}; + const auto& [method, info, portion] = GetParam(); + const TestDataId data_id{info.m, info.n, info.k, method.name}; // If the test data is already available, returns it. const auto data_it = _data.find(data_id); @@ -541,8 +227,6 @@ protected: } // Generates the test data. - const auto& method = matmul_methods.at(method_no); - const auto has_lhs_pack = method.packed_lhs_format.data_type() != DataType::UNKNOWN; const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN; const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN; @@ -617,9 +301,8 @@ std::map MatMulTest::_data; /// Tests the LHS packing kernel. TEST_P(MatMulTest, PackedLhs) { - const auto& [method_no, info, portion] = GetParam(); + const auto& [method, info, portion] = GetParam(); const auto& data = test_data(); - const auto& method = matmul_methods.at(method_no); if (method.fn_is_supported && !method.fn_is_supported()) { GTEST_SKIP(); @@ -668,9 +351,8 @@ TEST_P(MatMulTest, PackedLhs) { /// Tests the RHS packing kernel. TEST_P(MatMulTest, PackedRhs) { - const auto& [method_no, info, portion] = GetParam(); + const auto& [method, info, portion] = GetParam(); const auto& data = test_data(); - const auto& method = matmul_methods.at(method_no); if (method.fn_is_supported && !method.fn_is_supported()) { GTEST_SKIP(); @@ -739,9 +421,8 @@ TEST_P(MatMulTest, PackedRhs) { /// Tests the output. TEST_P(MatMulTest, Output) { - const auto& [method_no, info, portion] = GetParam(); + const auto& [method, info, portion] = GetParam(); const auto& data = test_data(); - const auto& method = matmul_methods.at(method_no); if (method.fn_is_supported && !method.fn_is_supported()) { GTEST_SKIP(); @@ -837,7 +518,7 @@ TEST_P(MatMulTest, Output) { INSTANTIATE_TEST_SUITE_P( MatMul, MatMulTest, testing::Combine( - testing::Range(0, matmul_methods.size()), + testing::ValuesIn(matmul_methods), testing::Values( MatMulShape{1, 16, 16}, // MatMulShape{20, 1, 20}, //