From e6a3432e2080661939b3ef2f47a44453df2f2109 Mon Sep 17 00:00:00 2001 From: Jakub Sujak Date: Wed, 5 Feb 2025 10:26:23 +0000 Subject: [PATCH 1/5] Extend benchmark tool to support all matrix multiplication micro-kernels * Refactor the benchmark tool to create a generic abstraction that allows for running matrix multiplication micro-kernels with different interfaces. * Extend benchmark support to all matrix multiplication micro-kernels in the library. Signed-off-by: Jakub Sujak --- CHANGELOG.md | 2 + CMakeLists.txt | 14 +- benchmark/main.cpp | 55 +- benchmark/matmul/matmul_benchmark_logic.hpp | 80 +++ benchmark/matmul/matmul_f32.cpp | 254 ---------- benchmark/matmul/matmul_f32.hpp | 21 - benchmark/matmul/matmul_f32_f32p_f32p.cpp | 114 ----- benchmark/matmul/matmul_f32_f32p_f32p.hpp | 15 - benchmark/matmul/matmul_interface.hpp | 74 +++ benchmark/matmul/matmul_registry.cpp | 529 ++++++++++++++++++++ benchmark/matmul/matmul_registry.hpp | 21 + benchmark/matmul/matmul_runner.hpp | 161 ++++++ benchmark/matmul/matmul_utils.hpp | 83 --- 13 files changed, 915 insertions(+), 508 deletions(-) create mode 100644 benchmark/matmul/matmul_benchmark_logic.hpp delete mode 100644 benchmark/matmul/matmul_f32.cpp delete mode 100644 benchmark/matmul/matmul_f32.hpp delete mode 100644 benchmark/matmul/matmul_f32_f32p_f32p.cpp delete mode 100644 benchmark/matmul/matmul_f32_f32p_f32p.hpp create mode 100644 benchmark/matmul/matmul_interface.hpp create mode 100644 benchmark/matmul/matmul_registry.cpp create mode 100644 benchmark/matmul/matmul_registry.hpp create mode 100644 benchmark/matmul/matmul_runner.hpp delete mode 100644 benchmark/matmul/matmul_utils.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 070054f0..b664057e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo ## Upcoming Release +- Extend benchmark tool to support all matrix multiplication micro-kernels. 
+
 ## v1.4.0
 
 - New Advanced SIMD micro-kernels:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6cf114d9..ea61ab8f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -363,14 +363,18 @@ if(KLEIDIAI_BUILD_BENCHMARK)
 
     add_executable(kleidiai_benchmark
         benchmark/main.cpp
-        benchmark/matmul/matmul_f32_f32p_f32p.cpp
-        benchmark/matmul/matmul_f32.cpp
-        test/common/cpu_info.cpp)
+        benchmark/matmul/matmul_benchmark_logic.hpp
+        benchmark/matmul/matmul_interface.hpp
+        benchmark/matmul/matmul_runner.hpp
+        benchmark/matmul/matmul_registry.hpp
+        benchmark/matmul/matmul_registry.cpp
+    )
 
     target_link_libraries(
         kleidiai_benchmark
-        kleidiai
-        benchmark::benchmark)
+        kleidiai_test_framework
+        benchmark::benchmark
+    )
 
     set(KLEIDIAI_BENCHMARK_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
     target_include_directories(kleidiai_benchmark
diff --git a/benchmark/main.cpp b/benchmark/main.cpp
index 35021804..36532c19 100644
--- a/benchmark/main.cpp
+++ b/benchmark/main.cpp
@@ -1,5 +1,5 @@
 //
-// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates
 //
 // SPDX-License-Identifier: Apache-2.0
 //
@@ -8,33 +8,47 @@
 #include <getopt.h>
 
 #include <benchmark/benchmark.h>
-#include <cstdio>
 #include <cstdlib>
+#include <iostream>
+#include <sstream>
 
-#include "benchmark/matmul/matmul_f32.hpp"
-#include "benchmark/matmul/matmul_f32_f32p_f32p.hpp"
+#include "benchmark/matmul/matmul_registry.hpp"
+#include "kai/kai_common.h"
 
+namespace {
 void print_usage(char* name) {
-    fprintf(stderr, "Usage:\n");
-    fprintf(stderr, "%s -m 13 -n 17 -k 18\n", name);
-    fprintf(stderr, "\n");
-    fprintf(stderr, "For additional options:\n");
-    fprintf(stderr, "%s --help\n", name);
+    std::ostringstream oss;
+    oss << "Usage:\n";
+    oss << "\t" << name << " -m 13 -n 17 -k 18 -b 32\n";
+    oss << "Options:\n";
+    oss << "\t-m,-n,-k";
+    oss << "\tMatrix dimensions\n";
+    oss << "\t-b";
+    oss << "\t\t\t(Optional) Block size for blockwise quantization\n";
+    oss << "For additional options:\n";
+    oss << "\t--help\n";
+    std::cerr << oss.str() << "\n";
 }
+} // namespace
 
 int main(int argc, char** argv) {
     ::benchmark::Initialize(&argc, argv);
 
+    std::ostringstream oss;
+    oss << "KleidiAI version: v" << kai_get_version() << "\n";
+
     bool mflag = false;
     bool nflag = false;
     bool kflag = false;
+    bool bflag = false;
 
-    size_t m = 0;
-    size_t n = 0;
-    size_t k = 0;
+    int64_t m = 1;
+    int64_t n = 1;
+    int64_t k = 1;
+    int64_t bl = 32;
 
     int opt;
-    while ((opt = getopt(argc, argv, "m:n:k:")) != -1) {
+    while ((opt = getopt(argc, argv, "m:n:k:b:")) != -1) {
         switch (opt) {
             case 'm':
                 m = atoi(optarg);
@@ -48,6 +62,10 @@ int main(int argc, char** argv) {
                 k = atoi(optarg);
                 kflag = true;
                 break;
+            case 'b':
+                bl = atoi(optarg);
+                bflag = true;
+                break;
             case '?':
                 // Fallthrough
             default:
@@ -61,9 +79,14 @@ int main(int argc, char** argv) {
         exit(EXIT_FAILURE);
     }
 
-    kai::bench::matmul_f32_qa8dxp_qs4cxp::dotprod::RegisterBenchmarks(m, n, k);
-    kai::bench::matmul_f32_qa8dxp_qs4cxp::i8mm::RegisterBenchmarks(m, n, k);
-    kai::bench::matmul_f32_f32p_f32p::RegisterBenchmarks(m, n, k);
+    if (!bflag) {
+        oss << "Optional argument -b not specified. Defaulting to block size " << bl << "\n";
+    }
+    std::cerr << oss.str();
+
+    for (const auto& benchmark : kai::benchmark::matmul_benchmarks) {
+        benchmark->Args({m, n, k, bl})->ArgNames({"m", "n", "k", "bl"});
+    }
 
     ::benchmark::RunSpecifiedBenchmarks();
     ::benchmark::Shutdown();
diff --git a/benchmark/matmul/matmul_benchmark_logic.hpp b/benchmark/matmul/matmul_benchmark_logic.hpp
new file mode 100644
index 00000000..a6ddc99e
--- /dev/null
+++ b/benchmark/matmul/matmul_benchmark_logic.hpp
@@ -0,0 +1,80 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#ifndef KLEIDIAI_BENCHMARK_MATMUL_MATMUL_BENCHMARK_LOGIC_HPP
+#define KLEIDIAI_BENCHMARK_MATMUL_MATMUL_BENCHMARK_LOGIC_HPP
+
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <vector>
+
+#include "benchmark/benchmark.h"
+#include "matmul_interface.hpp"
+#include "matmul_runner.hpp"
+
+namespace kai::benchmark {
+using Buffer = std::vector<uint8_t>;
+using CpuRequirement = std::function<bool(void)>;
+
+/// High level description of the matrix multiplication operation.
+enum class MatMulOp : uint8_t {
+    GEMM,
+    GEMV,
+};
+
+/// Benchmarks a matrix multiplication micro-kernel.
+///
+/// @tparam MatMulInterface Interface of the matrix multiplication micro-kernel.
+/// @param state State for the benchmark to use.
+/// @param matmul_interface Abstraction containing the micro-kernel to run.
+/// @param dst_type Output type of the micro-kernel. Required for the micro-kernel to make certain assumptions
+/// internally about the stride of the data.
+/// @param matmul_op Type of matrix multiplication operation.
+/// @param is_cpu_supported Function that checks the CPU feature requirement to run this benchmark.
+template <typename MatMulInterface>
+void kai_benchmark_matmul(
+    ::benchmark::State& state, const MatMulInterface matmul_interface, const DataType dst_type,
+    const MatMulOp matmul_op, const CpuRequirement& is_cpu_supported) {
+    if (!is_cpu_supported()) {
+        state.SkipWithMessage("Unsupported CPU feature");
+    }
+
+    const size_t m = state.range(0);
+    const size_t n = state.range(1);
+    const size_t k = state.range(2);
+    const size_t bl = state.range(3);
+
+    if (m > 1 && matmul_op == MatMulOp::GEMV) {
+        state.SkipWithMessage("GEMV optimized for m=1 only");
+    }
+
+    if constexpr (std::is_same_v<MatMulInterface, MatMulBlockwiseDynamicQuantInterface>) {
+        if (k % bl != 0) {
+            state.SkipWithMessage("K must be a multiple of block size");
+        }
+    }
+
+    // Create sufficiently large buffers
+    const size_t lhs_size = m * k * sizeof(uint64_t);
+    const size_t rhs_size = n * k * sizeof(uint64_t);
+    const size_t dst_size = m * n * sizeof(uint32_t);
+
+    const Buffer lhs(lhs_size);
+    const Buffer rhs(rhs_size);
+    Buffer dst(dst_size);
+
+    MatMulRunner matmul_runner(matmul_interface, dst_type);
+    matmul_runner.set_mnk(m, n, k);
+    matmul_runner.set_bl(bl);
+
+    for (auto _ : state) {
+        matmul_runner.run(lhs.data(), rhs.data(), dst.data());
+    }
+}
+} // namespace kai::benchmark
+
+#endif // KLEIDIAI_BENCHMARK_MATMUL_MATMUL_BENCHMARK_LOGIC_HPP
diff --git a/benchmark/matmul/matmul_f32.cpp b/benchmark/matmul/matmul_f32.cpp
deleted file mode 100644
index a796f9be..00000000
--- a/benchmark/matmul/matmul_f32.cpp
+++ /dev/null
@@ -1,254 +0,0 @@
-//
-// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
-//
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#include
-
-#include
-#include
-
-#include "benchmark/matmul/matmul_utils.hpp"
-#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h"
-#include 
"kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h" -#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" -#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" -#include "test/common/cpu_info.hpp" - -namespace kai::bench::matmul_f32_qa8dxp_qs4cxp { - -const size_t seed_lhs = 4568; -const size_t seed_rhs = seed_lhs + 4; -const size_t seed_bias = seed_rhs + 4; - -struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { - kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel ukernel; - std::string name = {}; -}; - -struct kai_matmul_f32_qai8_qsi4 { - template - void operator()( - benchmark::State& state, kai_matmul_ukernel_f32_qa8dxp_qs4cxp variant, size_t m, size_t n, size_t k) const { - const size_t lhs_native_size_f32 = m * k * sizeof(float); - const size_t rhs_native_size_f32 = n * k * sizeof(float); - const size_t rhs_native_size_qs4cx = n * (k / 2) * sizeof(uint8_t); - const size_t rhs_scales_size_f32 = n * sizeof(float); - - // Allocate the memory - uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32]; - uint8_t* rhs_native_mtx_f32 = new uint8_t[rhs_native_size_f32]; - uint8_t* rhs_native_mtx_qs4cx = new uint8_t[rhs_native_size_qs4cx]; - uint8_t* rhs_scales_f32 = new uint8_t[rhs_scales_size_f32]; - - fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs); - fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs); - - quant_qs4cx_f32(n, k, (const float*)rhs_native_mtx_f32, rhs_native_mtx_qs4cx, (float*)rhs_scales_f32); - - delete[] rhs_native_mtx_f32; - - // Get the packing parameters - const size_t mr = variant.ukernel.get_mr(); - const size_t nr = variant.ukernel.get_nr(); - const size_t kr = variant.ukernel.get_kr(); - const size_t sr = variant.ukernel.get_sr(); - - // Get the size in bytes for the packed matrices - const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr); - const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(n, k, nr, kr, sr); - const size_t dst_size = variant.ukernel.get_dst_size(m, n); - - // Allocate the matrices - uint8_t* lhs_packed_mtx_qa8dx = new uint8_t[lhs_packed_size]; - uint8_t* rhs_packed_mtx_qs4cx = new uint8_t[rhs_packed_size]; - uint8_t* dst_act_mtx_f32 = new uint8_t[dst_size]; - - // If the RHS matrix contains constant values, the packing can be performed - // only once - struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params; - params.lhs_zero_point = 1; - params.rhs_zero_point = 8; - - // RHS packing - kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( - 1, n, k, nr, kr, sr, // Packing arguments - (const 
uint8_t*)(rhs_native_mtx_qs4cx), // RHS - NULL, // Bias - (const float*)(rhs_scales_f32), // Scale - rhs_packed_mtx_qs4cx, // RHS packed - 0, ¶ms); - - // LHS packing - kai_run_lhs_quant_pack_qai8dxp_f32( - m, k, mr, kr, sr, 0, // Packing arguments - (const float*)lhs_native_mtx_f32, // LHS - k * sizeof(float), // LHS stride - lhs_packed_mtx_qa8dx); // LHS packed - - const size_t dst_stride = n * sizeof(float); - const size_t lhs_offset = variant.ukernel.get_lhs_packed_offset(0, k); - const size_t rhs_offset = variant.ukernel.get_rhs_packed_offset(0, k); - const size_t dst_offset = variant.ukernel.get_dst_offset(0, 0, dst_stride); - - const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx + lhs_offset); - const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4cx + rhs_offset); - float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + dst_offset); - - // Matmul - for (auto _ : state) { - variant.ukernel.run_matmul( - m, n, k, // Dimensions - lhs_ptr, // LHS packed - rhs_ptr, // RHS packed - dst_ptr, // DST - dst_stride, // DST stride (row) - sizeof(float), // DST stride (col) - -FLT_MAX, FLT_MAX // Min and max for the clamp operation - ); - } - - delete[] lhs_packed_mtx_qa8dx; - delete[] rhs_packed_mtx_qs4cx; - delete[] dst_act_mtx_f32; - - delete[] lhs_native_mtx_f32; - delete[] rhs_native_mtx_qs4cx; - delete[] rhs_scales_f32; - } -}; /* kai_matmul_f32_qai8_qsi4 */ - -namespace dotprod { -kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { - {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - "matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - "matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod, - 
kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod, - kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod, - kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod, - kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod, - kai_run_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod, - "matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, - "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod"}, -}; - -void RegisterBenchmarks(size_t m, size_t n, size_t k) { - kai::bench::matmul_f32_qa8dxp_qs4cxp::kai_matmul_f32_qai8_qsi4 matmul_f32; - if (kai::test::cpu_has_dotprod()) { - for (const auto& variant : dotprod::ukernel_variants) { - ::benchmark::RegisterBenchmark(variant.name, matmul_f32, variant, m, n, k); - } - } -} - -} /* namespace dotprod */ - -namespace i8mm { -kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - 
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm"}, -}; - -void RegisterBenchmarks(size_t m, size_t n, size_t k) { - kai::bench::matmul_f32_qa8dxp_qs4cxp::kai_matmul_f32_qai8_qsi4 matmul_f32; - if (kai::test::cpu_has_i8mm()) { - for (const auto& variant : i8mm::ukernel_variants) { - ::benchmark::RegisterBenchmark(variant.name, matmul_f32, variant, m, n, k); - } - } -} - -}; /* namespace i8mm */ -}; // namespace kai::bench::matmul_f32_qa8dxp_qs4cxp diff --git a/benchmark/matmul/matmul_f32.hpp b/benchmark/matmul/matmul_f32.hpp deleted file mode 100644 index 1245ca8a..00000000 --- a/benchmark/matmul/matmul_f32.hpp +++ /dev/null @@ -1,21 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include - -namespace kai::bench::matmul_f32_qa8dxp_qs4cxp { - -namespace dotprod { -void RegisterBenchmarks(size_t m, size_t n, size_t k); -}; /* namespace dotprod */ - -namespace i8mm { -void RegisterBenchmarks(size_t m, size_t n, size_t k); -}; /* namespace i8mm */ - -}; // namespace kai::bench::matmul_f32_qa8dxp_qs4cxp diff --git a/benchmark/matmul/matmul_f32_f32p_f32p.cpp b/benchmark/matmul/matmul_f32_f32p_f32p.cpp deleted file mode 100644 index 6fe1ebb0..00000000 --- a/benchmark/matmul/matmul_f32_f32p_f32p.cpp +++ /dev/null @@ -1,114 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 
-// - -#include "benchmark/matmul/matmul_f32_f32p_f32p.hpp" - -#include - -#include "benchmark/matmul/matmul_utils.hpp" -#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32p_interface.h" -#include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" -#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" -#include "test/common/cpu_info.hpp" - -namespace kai::bench::matmul_f32_f32p_f32p { - -const size_t seed_lhs = 4568; -const size_t seed_rhs = seed_lhs + 4; -const size_t seed_bias = seed_rhs + 4; - -struct kai_matmul_ukernel_f32_f32p_f32p { - kai_matmul_clamp_f32_f32p_f32p_ukernel ukernel; - std::string name = {}; -}; - -kai_matmul_ukernel_f32_f32p_f32p sme_variants[] = { - {kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - "matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa"}, -}; - -struct kai_matmul_f32_f32p_f32p_sme { - template - void operator()( - benchmark::State& state, kai_matmul_ukernel_f32_f32p_f32p variant, size_t m, size_t n, size_t k) const { - const size_t lhs_size = m * k; - const size_t rhs_size = n * k; - const size_t bias_size = n; - const size_t dst_size = m * n; - - float* lhs = new float[lhs_size]; - float* rhs = new float[rhs_size]; - float* bias = new float[bias_size]; - - fill_uniform_random(m, k, lhs, seed_lhs); - fill_uniform_random(k, n, rhs, seed_rhs); - fill_uniform_random(1, n, bias, seed_bias); - - const size_t mr = variant.ukernel.get_mr(); - const size_t nr = variant.ukernel.get_nr(); - const size_t kr = variant.ukernel.get_kr(); - const size_t sr = variant.ukernel.get_sr(); - - const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(m, k, mr, kr, sr); - float* lhs_packed = new float[lhs_packed_size / sizeof(float)]; - - const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(n, k); - - float* rhs_packed = new float[rhs_packed_size / sizeof(float)]; - - const size_t lhs_stride = k * sizeof(float); - const size_t rhs_stride = n * sizeof(float); - const size_t dst_stride_row = n * sizeof(float); - const size_t dst_stride_col = sizeof(float); - kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( - 1, n, k, nr, kr, sr, // Packing arguments - rhs_stride, rhs, bias, NULL, rhs_packed, 0, NULL); - - kai_run_lhs_pack_f32p2vlx1_f32_sme(m, k, mr, kr, sr, 0, lhs, k * sizeof(float), lhs_packed); - - float* dst = new float[dst_size]; - for (auto _ : state) { - // run matmul - variant.ukernel.run_matmul( - m, n, k, // Dimensions - lhs_packed, // LHS - rhs_packed, // RHS packed - dst, // DST - dst_stride_row, // DST stride (row) - dst_stride_col, // DST stride (col) - FLT_MIN, FLT_MAX 
// Min and max for the clamp operation
-        );
-    }
-
-    delete[] lhs;
-    delete[] rhs;
-    delete[] bias;
-    delete[] rhs_packed;
-    delete[] lhs_packed;
-    delete[] dst;
-    }
-}; /* struct kai_matmul_f32_f32p_f32p_sme */
-
-void RegisterBenchmarks(size_t m, size_t n, size_t k) {
-    kai_matmul_f32_f32p_f32p_sme sme_kernel;
-    if (kai::test::cpu_has_sme2()) {
-        for (const auto& variant : sme_variants) {
-            ::benchmark::RegisterBenchmark(variant.name, sme_kernel, variant, m, n, k)->Iterations(2000);
-        }
-    }
-}
-} // namespace kai::bench::matmul_f32_f32p_f32p
diff --git a/benchmark/matmul/matmul_f32_f32p_f32p.hpp b/benchmark/matmul/matmul_f32_f32p_f32p.hpp
deleted file mode 100644
index f05597ce..00000000
--- a/benchmark/matmul/matmul_f32_f32p_f32p.hpp
+++ /dev/null
@@ -1,15 +0,0 @@
-//
-// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
-//
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#pragma once
-
-#include
-
-namespace kai::bench::matmul_f32_f32p_f32p {
-
-void RegisterBenchmarks(size_t m, size_t n, size_t k);
-
-};
diff --git a/benchmark/matmul/matmul_interface.hpp b/benchmark/matmul/matmul_interface.hpp
new file mode 100644
index 00000000..3b5f18f8
--- /dev/null
+++ b/benchmark/matmul/matmul_interface.hpp
@@ -0,0 +1,74 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#ifndef KLEIDIAI_BENCHMARK_MATMUL_MATMUL_INTERFACE_HPP
+#define KLEIDIAI_BENCHMARK_MATMUL_MATMUL_INTERFACE_HPP
+
+#include <cstddef>
+
+#include "kai/kai_common.h"
+
+namespace kai::benchmark {
+
+/// Abstraction for the unspecialized Matrix Multiplication micro-kernel interface
+struct MatMulBaseInterface {
+    void (*run_matmul)(
+        size_t m, size_t n, size_t k, //
+        const void* lhs_packed, //
+        const void* rhs_packed, //
+        void* dst, //
+        size_t dst_stride_row, size_t dst_stride_col, //
+        float clamp_min, float clamp_max);
+};
+
+/// Abstraction for the unspecialized Matrix Multiplication micro-kernel interface with a strided LHS matrix
+struct MatMulStridedLhsInterface {
+    void (*run_matmul)(
+        size_t m, size_t n, size_t k, //
+        const void* lhs_packed, //
+        size_t lhs_stride, //
+        const void* rhs_packed, //
+        void* dst, //
+        size_t dst_stride_row, size_t dst_stride_col, //
+        float clamp_min, float clamp_max);
+};
+
+/// Abstraction for the Matrix Multiplication micro-kernel interface with a floating point destination buffer
+struct MatMulFloatInterface {
+    void (*run_matmul)(
+        size_t m, size_t n, size_t k, //
+        const void* lhs_packed, //
+        const void* rhs_packed, //
+        float* dst, //
+        size_t dst_stride_row, size_t dst_stride_col, //
+        float clamp_min, float clamp_max);
+};
+
+/// Abstraction for the Matrix Multiplication micro-kernel with static quantization
+struct MatMulStaticQuantInterface {
+    void (*run_matmul)(
+        size_t m, size_t n, size_t k, //
+        const void* lhs_packed, //
+        const void* rhs_packed, //
+        void* dst, //
+        size_t dst_stride_row, size_t dst_stride_col, //
+        const kai_matmul_requantize32_params* params);
+};
+
+/// Abstraction for the Matrix Multiplication micro-kernel with dynamic blockwise quantization
+struct MatMulBlockwiseDynamicQuantInterface {
+    void (*run_matmul)(
+        size_t m, size_t n, size_t k, size_t bl, //
+        const void* lhs_packed, //
+        const void* rhs_packed, //
+        float* dst, //
+        size_t dst_stride_row, size_t dst_stride_col, //
+        float clamp_min, float clamp_max);
+};
+
+} // namespace kai::benchmark
+
+#endif // KLEIDIAI_BENCHMARK_MATMUL_MATMUL_INTERFACE_HPP
diff --git a/benchmark/matmul/matmul_registry.cpp 
b/benchmark/matmul/matmul_registry.cpp new file mode 100644 index 00000000..286a7c84 --- /dev/null +++ b/benchmark/matmul/matmul_registry.cpp @@ -0,0 +1,529 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "matmul_registry.hpp" + +#include +#include +#include + +#include "benchmark/benchmark.h" +#include "matmul_benchmark_logic.hpp" +#include "matmul_interface.hpp" + +// Micro-kernels to register for benchmarking + +// matmul_clamp_f16_bf16p_bf16p +#include "kai/ukernels/matmul/matmul_clamp_f16_bf16p_bf16p/kai_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h" + +// matmul_clamp_f16_f16_f16p +#include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.h" + +// matmul_clamp_f16_f16p_f16p +#include "kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" + +// matmul_clamp_f32_bf16p_bf16p +#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h" + +// matmul_clamp_f32_f32_f32p +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h" + +// matmul_clamp_f32_f32p_f32p +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" + +// matmul_clamp_f32_qai8dxp_qsi4c32p +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h" + +// matmul_clamp_f32_qai8dxp_qsi4cxp +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.h" +#include 
"kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h" + +// matmul_clamp_f32_qai8dxp_qsi8cxp +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h" + +// matmul_clamp_f32_qsi8d32p_qsi4c32p +#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" + +// matmul_clamp_fp32_bf16p_bf16p +#include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" + +// matmul_clamp_qai8_qai8_qsi8cxp +#include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h" + +// matmul_clamp_qai8_qai8p_qsi8cxp +#include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" + +namespace kai::benchmark { + +using DataType = test::DataType; + +// matmul_clamp_f16_bf16p_bf16p +inline constexpr MatMulBaseInterface kai_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla_interface{ + .run_matmul = kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, +}; + +// matmul_clamp_f16_f16_f16p +inline constexpr MatMulStridedLhsInterface 
kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot_interface{ + .run_matmul = kai_run_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot, +}; + +inline constexpr MatMulStridedLhsInterface kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla_interface{ + .run_matmul = kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, +}; + +// matmul_clamp_f16_f16p_f16p +inline constexpr MatMulBaseInterface kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_interface{ + .run_matmul = kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, +}; + +// matmul_clamp_f32_bf16p_bf16p +inline constexpr MatMulBaseInterface kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot_interface{ + .run_matmul = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, +}; + +inline constexpr MatMulBaseInterface kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla_interface{ + .run_matmul = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, +}; + +// matmul_clamp_f32_f32_f32p +inline constexpr MatMulStridedLhsInterface kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla_interface{ + .run_matmul = kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla, +}; + +inline constexpr MatMulStridedLhsInterface kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla_interface{ + .run_matmul = kai_run_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, +}; + +inline constexpr MatMulStridedLhsInterface kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla_interface{ + .run_matmul = kai_run_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla, +}; + +// matmul_clamp_f32_f32p_f32p +inline constexpr MatMulBaseInterface kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa_interface{ + .run_matmul = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, +}; + +// matmul_clamp_f32_qai8dxp_qsi4c32p +inline constexpr MatMulBlockwiseDynamicQuantInterface + kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + }; + +inline constexpr MatMulBlockwiseDynamicQuantInterface + kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod, + }; + +inline constexpr MatMulBlockwiseDynamicQuantInterface + kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + }; + +inline constexpr MatMulBlockwiseDynamicQuantInterface + kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8_neon_dotprod_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8_neon_dotprod, + }; +inline constexpr MatMulBlockwiseDynamicQuantInterface + kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + }; + +inline constexpr MatMulBlockwiseDynamicQuantInterface + kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + }; + +inline constexpr MatMulBlockwiseDynamicQuantInterface + kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod, + }; + +inline constexpr MatMulBlockwiseDynamicQuantInterface + kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_interface{ + .run_matmul = 
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + }; + +inline constexpr MatMulBlockwiseDynamicQuantInterface + kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + }; + +inline constexpr MatMulBlockwiseDynamicQuantInterface + kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + }; + +// matmul_clamp_f32_qai8dxp_qsi4cxp +inline constexpr MatMulFloatInterface kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa, +}; + +inline constexpr MatMulFloatInterface kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot, +}; + +inline constexpr MatMulFloatInterface kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod, +}; + +inline constexpr MatMulFloatInterface kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, +}; + +inline constexpr MatMulFloatInterface kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, +}; + +inline constexpr MatMulFloatInterface kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod, +}; + +inline constexpr MatMulFloatInterface kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, +}; + +inline constexpr MatMulFloatInterface kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, +}; + +inline constexpr MatMulFloatInterface kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, +}; + +inline constexpr MatMulFloatInterface kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, +}; + +inline constexpr MatMulFloatInterface kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, +}; + +// matmul_clamp_f32_qai8dxp_qsi8cxp +inline constexpr MatMulFloatInterface kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod, +}; + +inline constexpr MatMulFloatInterface kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod, +}; + +inline constexpr MatMulFloatInterface kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod, +}; + +inline constexpr MatMulFloatInterface kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm_interface{ + .run_matmul = 
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+};
+
+// matmul_clamp_f32_qsi8d32p_qsi4c32p
+inline constexpr MatMulBlockwiseDynamicQuantInterface
+    kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa_interface{
+        .run_matmul = kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+    };
+
+inline constexpr MatMulBlockwiseDynamicQuantInterface
+    kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot_interface{
+        .run_matmul = kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+    };
+
+inline constexpr MatMulBlockwiseDynamicQuantInterface
+    kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod_interface{
+        .run_matmul = kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+    };
+
+inline constexpr MatMulBlockwiseDynamicQuantInterface
+    kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod_interface{
+        .run_matmul = kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+    };
+
+inline constexpr MatMulBlockwiseDynamicQuantInterface
+    kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod_interface{
+        .run_matmul = kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+    };
+
+inline constexpr MatMulBlockwiseDynamicQuantInterface
+    kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm_interface{
+        .run_matmul = kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+    };
+
+inline constexpr MatMulBlockwiseDynamicQuantInterface
+    kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm_interface{
+        .run_matmul = kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+    };
+
+// matmul_clamp_fp32_bf16p_bf16p
+inline constexpr MatMulBaseInterface kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_interface{
+    .run_matmul = kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+};
+
+// matmul_clamp_qai8_qai8_qsi8cxp
+inline constexpr MatMulStaticQuantInterface kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface{
+    .run_matmul = kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot,
+};
+
+// matmul_clamp_qai8_qai8p_qsi8cxp
+inline constexpr MatMulStaticQuantInterface kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_interface{
+    .run_matmul = kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
+};
+
+const std::array<::benchmark::internal::Benchmark*, 45> matmul_benchmarks{
+    // matmul_clamp_f16_bf16p_bf16p
+    RegisterBenchmark(
+        "kai_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla", kai_benchmark_matmul,
+        kai_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla_interface, DataType::FP16, MatMulOp::GEMM,
+        test::cpu_has_bf16),
+
+    // matmul_clamp_f16_f16_f16p
+    RegisterBenchmark(
+        "kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot", kai_benchmark_matmul,
+        kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot_interface, DataType::FP16, MatMulOp::GEMV,
+        test::cpu_has_sme2),
+    RegisterBenchmark(
+        "kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla", kai_benchmark_matmul,
+        kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla_interface, DataType::FP16, MatMulOp::GEMM,
+        test::cpu_has_fp16),
+
+    // matmul_clamp_f16_f16p_f16p
+    RegisterBenchmark(
+        "kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", kai_benchmark_matmul,
+        kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa_interface, DataType::FP16, MatMulOp::GEMM,
+        test::cpu_has_sme2),
+
+    // matmul_clamp_f32_bf16p_bf16p
+    RegisterBenchmark(
+        "kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot", 
kai_benchmark_matmul, + kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot_interface, DataType::FP32, MatMulOp::GEMV, + test::cpu_has_dotprod), + RegisterBenchmark( + "kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla", kai_benchmark_matmul, + kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_i8mm), + + // matmul_clamp_f32_f32_f32p + RegisterBenchmark( + "kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla", kai_benchmark_matmul, + kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla_interface, DataType::FP32, MatMulOp::GEMV, + test::cpu_has_sme2), + RegisterBenchmark( + "kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla", kai_benchmark_matmul, + kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_i8mm), + RegisterBenchmark( + "kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla", kai_benchmark_matmul, + kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla_interface, DataType::FP32, MatMulOp::GEMV, + test::cpu_has_sme2), + + // matmul_clamp_f32_f32p_f32p + RegisterBenchmark( + "kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa", kai_benchmark_matmul, + kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_sme2), + + // matmul_clamp_f32_qai8dxp_qsi4c32p + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod", + kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_interface, DataType::FP32, MatMulOp::GEMV, + test::cpu_has_dotprod), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod", + kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod_interface, DataType::FP32, MatMulOp::GEMV, + test::cpu_has_dotprod), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod", + kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_interface, DataType::FP32, MatMulOp::GEMV, + test::cpu_has_dotprod), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod", + kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod_interface, DataType::FP32, MatMulOp::GEMV, + test::cpu_has_dotprod), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8_neon_dotprod", + kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8_neon_dotprod_interface, DataType::FP32, MatMulOp::GEMV, + test::cpu_has_dotprod), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod", + kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_dotprod), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod", + kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_dotprod), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm", + kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_i8mm), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm", + kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_interface, DataType::FP32, 
MatMulOp::GEMM, + test::cpu_has_i8mm), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm", + kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_i8mm), + + // matmul_clamp_f32_qai8dxp_qsi4cxp + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa", kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_sme2), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot", kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot_interface, DataType::FP32, MatMulOp::GEMV, + test::cpu_has_sme2), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod", kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod_interface, DataType::FP32, MatMulOp::GEMV, + test::cpu_has_dotprod), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod", kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod_interface, DataType::FP32, MatMulOp::GEMV, + test::cpu_has_dotprod), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod", kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod_interface, DataType::FP32, MatMulOp::GEMV, + test::cpu_has_dotprod), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod", kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_dotprod), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod", kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_dotprod), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm", kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_i8mm), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm", kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_i8mm), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm", kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_i8mm), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm", kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_i8mm), + + // matmul_clamp_f32_qai8dxp_qsi8cxp + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod", kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod_interface, DataType::FP32, MatMulOp::GEMV, + test::cpu_has_dotprod), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod", kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod_interface, DataType::FP32, MatMulOp::GEMV, + test::cpu_has_dotprod), + RegisterBenchmark( + 
"kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod", kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_dotprod), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm", kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_i8mm), + + // matmul_clamp_f32_qsi8d32p_qsi4c32p + RegisterBenchmark( + "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa", + kai_benchmark_matmul, + kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_sme2), + RegisterBenchmark( + "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot", + kai_benchmark_matmul, + kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot_interface, DataType::FP32, MatMulOp::GEMV, + test::cpu_has_sme2), + RegisterBenchmark( + "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod", + kai_benchmark_matmul, + kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod_interface, DataType::FP32, MatMulOp::GEMV, + test::cpu_has_dotprod), + RegisterBenchmark( + "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod", + kai_benchmark_matmul, + kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod_interface, DataType::FP32, MatMulOp::GEMV, + test::cpu_has_dotprod), + RegisterBenchmark( + "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod", + kai_benchmark_matmul, + kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_dotprod), + RegisterBenchmark( + "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm", + kai_benchmark_matmul, + kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_i8mm), + RegisterBenchmark( + "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm", + kai_benchmark_matmul, + kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_i8mm), + + // matmul_clamp_fp32_bf16p_bf16p + RegisterBenchmark( + "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa", kai_benchmark_matmul, + kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_sme2), + + // matmul_clamp_qai8_qai8_qsi8cxp + RegisterBenchmark( + "kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot", kai_benchmark_matmul, + kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface, DataType::QAI8, MatMulOp::GEMV, + test::cpu_has_sme2), + + // matmul_clamp_qai8_qai8p_qsi8cxp + RegisterBenchmark( + "kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", + kai_benchmark_matmul, + kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_interface, DataType::QAI8, MatMulOp::GEMM, + test::cpu_has_sme2), +}; + +} // namespace kai::benchmark diff --git a/benchmark/matmul/matmul_registry.hpp b/benchmark/matmul/matmul_registry.hpp new file mode 100644 index 00000000..644be5c2 --- /dev/null +++ b/benchmark/matmul/matmul_registry.hpp @@ -0,0 +1,21 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#ifndef KLEIDIAI_BENCHMARK_MATMUL_MATMUL_REGISTRY_HPP +#define KLEIDIAI_BENCHMARK_MATMUL_MATMUL_REGISTRY_HPP + +#include + 
+#include "benchmark/benchmark.h" + +namespace kai::benchmark { + +/// Array of registered matrix multiplication benchmarks +extern const std::array<::benchmark::internal::Benchmark*, 45> matmul_benchmarks; + +} // namespace kai::benchmark + +#endif // KLEIDIAI_BENCHMARK_MATMUL_MATMUL_REGISTRY_HPP diff --git a/benchmark/matmul/matmul_runner.hpp b/benchmark/matmul/matmul_runner.hpp new file mode 100644 index 00000000..0f04cd65 --- /dev/null +++ b/benchmark/matmul/matmul_runner.hpp @@ -0,0 +1,161 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#ifndef KLEIDIAI_BENCHMARK_MATMUL_MATMUL_RUNNER_HPP +#define KLEIDIAI_BENCHMARK_MATMUL_MATMUL_RUNNER_HPP + +#include +#include +#include +#include + +#include "kai/kai_common.h" +#include "matmul_interface.hpp" + +namespace kai::benchmark { + +using DataType = test::DataType; + +/// Runner for the matrix multiplication micro-kernel. +/// +/// Prepares and executes the run method of the micro-kernel. +/// +/// @tparam MatMulInterface Interface of the matrix multiplication micro-kernel. +template <typename MatMulInterface> +class MatMulRunner { +public: + /// Constructs a MatMulRunner object. + /// + /// @param matmul_interface Abstraction containing the micro-kernel to run. + /// @param dst_type Output type of the micro-kernel. Required for the micro-kernel to make certain assumptions + /// internally about the stride of the data. + MatMulRunner(const MatMulInterface& matmul_interface, const DataType dst_type) : + matmul_interface_(matmul_interface), dst_type_(dst_type) { + } + + /// Sets the M, N and K dimensions to describe the operand and result matrices. + /// + /// @param m Rows in a non-transposed LHS and DST matrix. + /// @param n Columns in a non-transposed RHS and DST matrix. + /// @param k Columns in a non-transposed LHS matrix, and rows in a non-transposed RHS matrix. + void set_mnk(const size_t m, const size_t n, const size_t k) { + m_ = m; + n_ = n; + k_ = k; + + lhs_stride_ = k_ * data_type_size_in_bits(dst_type_) / 8; + dst_stride_row_ = n_ * data_type_size_in_bits(dst_type_) / 8; + dst_stride_col_ = data_type_size_in_bits(dst_type_) / 8; + } + + /// Sets the block size to use. + /// + /// @param bl Block size. Used for micro-kernels with dynamic blockwise quantization. + void set_bl(const size_t bl) { + bl_ = bl; + } + + /// Runs the matrix multiplication micro-kernel. + /// + /// @param lhs Buffer containing LHS matrix data. + /// @param rhs Buffer containing RHS matrix data. + /// @param dst Destination buffer to write to. + void run(const void* lhs, const void* rhs, void* dst); + +private: + MatMulInterface matmul_interface_ = {}; + + DataType dst_type_ = DataType::FP32; + + size_t m_ = 1; + size_t n_ = 1; + size_t k_ = 1; + size_t bl_ = 32; + + size_t lhs_stride_ = 1; + size_t dst_stride_row_ = 1; + size_t dst_stride_col_ = 1; +}; + +/// Runs the matrix multiplication micro-kernel. +/// +/// @param lhs Buffer containing LHS matrix data. +/// @param rhs Buffer containing RHS matrix data. +/// @param dst Destination buffer to write to. +template <typename MatMulInterface> +void MatMulRunner<MatMulInterface>::run(const void* lhs, const void* rhs, void* dst) { + matmul_interface_.run_matmul( + m_, n_, k_, // + lhs, rhs, dst, // + dst_stride_row_, dst_stride_col_, // + -FLT_MAX, FLT_MAX // + ); +} + +/// Runs the matrix multiplication micro-kernel. Specialized on the strided LHS interface. +/// +/// @param lhs Buffer containing LHS matrix data. +/// @param rhs Buffer containing RHS matrix data.
+/// @param dst Destination buffer to write to. +template <> +inline void MatMulRunner<MatMulStridedLhsInterface>::run(const void* lhs, const void* rhs, void* dst) { + matmul_interface_.run_matmul( + m_, n_, k_, // + lhs, lhs_stride_, rhs, dst, // + dst_stride_row_, dst_stride_col_, // + -FLT_MAX, FLT_MAX // + ); +} + +/// Runs the matrix multiplication micro-kernel. Specialized on the interface with a floating point destination buffer. +/// +/// @param lhs Buffer containing LHS matrix data. +/// @param rhs Buffer containing RHS matrix data. +/// @param dst Destination buffer to write to. +template <> +inline void MatMulRunner<MatMulFloatInterface>::run(const void* lhs, const void* rhs, void* dst) { + matmul_interface_.run_matmul( + m_, n_, k_, // + lhs, rhs, static_cast<float*>(dst), // + dst_stride_row_, dst_stride_col_, // + -FLT_MAX, FLT_MAX // + ); +} + +/// Runs the matrix multiplication micro-kernel. Specialized on the static quantization interface. +/// +/// @param lhs Buffer containing LHS matrix data. +/// @param rhs Buffer containing RHS matrix data. +/// @param dst Destination buffer to write to. +template <> +inline void MatMulRunner<MatMulStaticQuantInterface>::run(const void* lhs, const void* rhs, void* dst) { + constexpr kai_matmul_requantize32_params params = {INT8_MIN, INT8_MAX, 0}; + matmul_interface_.run_matmul( + m_, n_, k_, // + lhs, rhs, dst, // + dst_stride_row_, dst_stride_col_, // + &params // + ); +} + +/// Runs the matrix multiplication micro-kernel. Specialized on the dynamic blockwise quantization interface. +/// +/// @param lhs Buffer containing LHS matrix data. +/// @param rhs Buffer containing RHS matrix data. +/// @param dst Destination buffer to write to. +template <> +inline void MatMulRunner<MatMulBlockwiseDynamicQuantInterface>::run(const void* lhs, const void* rhs, void* dst) { + matmul_interface_.run_matmul( + m_, n_, k_, bl_, // + lhs, rhs, static_cast<float*>(dst), // + dst_stride_row_, dst_stride_col_, // + -FLT_MAX, FLT_MAX // + ); +} + +} // namespace kai::benchmark + +#endif // KLEIDIAI_BENCHMARK_MATMUL_MATMUL_RUNNER_HPP diff --git a/benchmark/matmul/matmul_utils.hpp b/benchmark/matmul/matmul_utils.hpp deleted file mode 100644 index a19fdbd1..00000000 --- a/benchmark/matmul/matmul_utils.hpp +++ /dev/null @@ -1,83 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include -#include -#include -#include -#include -#include - -#define INT4_MIN (-8) -#define INT4_MAX (7) - -static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, size_t seed) { - std::srand(seed); - - // Fill the array with random values between -1 and 1 - for (int i = 0; i < num_rows * num_cols; i++) { - dst[i] = (float)((double)std::rand() / RAND_MAX) * 2 - 1; - } -} - -static void quant_qs4cx_f32(size_t n, size_t k, const float* rhs_f32, uint8_t* rhs_qs4cx, float* rhs_scales_f32) { - const size_t dst_stride = (k / 2) * sizeof(int8_t); - - for (size_t row_idx = 0; row_idx < n; ++row_idx) { - const float* src_ptr = rhs_f32 + row_idx * k; - - float max0 = -FLT_MAX; - float min0 = FLT_MAX; - - // Find min/max for each channel - for (size_t k_idx = 0; k_idx < k; ++k_idx) { - const float src0_0 = src_ptr[k_idx]; - - max0 = std::max(src0_0, max0); - min0 = std::min(src0_0, min0); - } - - // Maximum/minimum int8 values - const float qmin = (float)INT4_MIN; - const float qmax = (float)INT4_MAX; - - const float rmin0 = std::min(0.0f, min0); - const float rmax0 = std::max(0.0f, max0); - - const float scale0 = rmin0 == rmax0 ?
1.f : (qmax - qmin) / (rmax0 - rmin0); - - // Reciprocal to quantize - const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f; - - uint8_t* dst_ptr = rhs_qs4cx + row_idx * dst_stride; - - // Quantize the channels - for (size_t k_idx = 0; k_idx < k; k_idx += 2) { - const float src0_0 = src_ptr[k_idx + 0]; - const float src0_1 = src_ptr[k_idx + 1]; - - // Scale the values - int32_t v0_s32 = (int32_t)(round(src0_0 * scale0)); - int32_t v1_s32 = (int32_t)(round(src0_1 * scale0)); - - // Maximum/minimum int4 values - v0_s32 = std::clamp(v0_s32, INT4_MIN, INT4_MAX); - v1_s32 = std::clamp(v1_s32, INT4_MIN, INT4_MAX); - - int32_t v0_u8 = (uint8_t)(v0_s32 + 8); - int32_t v1_u8 = (uint8_t)(v1_s32 + 8); - - const uint8_t rhs_v0 = (v1_u8 << 4) | v0_u8; - - dst_ptr[0] = rhs_v0; - dst_ptr += sizeof(uint8_t); - } - - rhs_scales_f32[row_idx] = recip_scale0; - } -}; -- GitLab From 7215c5562608600c72b2cc918b3c1219c3cc7493 Mon Sep 17 00:00:00 2001 From: Jakub Sujak Date: Mon, 24 Feb 2025 12:56:15 +0000 Subject: [PATCH 2/5] Tidy up benchmark registration Signed-off-by: Jakub Sujak --- CMakeLists.txt | 18 +++++------------- benchmark/main.cpp | 12 +++++------- benchmark/matmul/matmul_registry.cpp | 14 ++++++++++++-- benchmark/matmul/matmul_registry.hpp | 12 ++++++++---- 4 files changed, 30 insertions(+), 26 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ea61ab8f..404a54d3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -338,7 +338,6 @@ if(KLEIDIAI_BUILD_TESTS) PRIVATE GTest::gtest_main ) - # Cross-compiling is a common use case which creates a conflict if DISCOVERY_MODE is set to POST_BUILD (by default) # since the host platform does not match the target. Setting the mode to PRE_TEST avoids this conflict. This feature # was added in CMake 3.18 @@ -363,20 +362,13 @@ if(KLEIDIAI_BUILD_BENCHMARK) add_executable(kleidiai_benchmark benchmark/main.cpp - benchmark/matmul/matmul_benchmark_logic.hpp - benchmark/matmul/matmul_interface.hpp - benchmark/matmul/matmul_runner.hpp - benchmark/matmul/matmul_registry.hpp benchmark/matmul/matmul_registry.cpp ) - target_link_libraries( - kleidiai_benchmark - kleidiai_test_framework - benchmark::benchmark + target_link_libraries(kleidiai_benchmark + PRIVATE + kleidiai + kleidiai_test_framework + benchmark::benchmark ) - - set(KLEIDIAI_BENCHMARK_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) - target_include_directories(kleidiai_benchmark - PRIVATE KLEIDIAI_BENCHMARK_INCLUDE_DIR) endif() diff --git a/benchmark/main.cpp b/benchmark/main.cpp index 36532c19..0f6a0edd 100644 --- a/benchmark/main.cpp +++ b/benchmark/main.cpp @@ -42,10 +42,10 @@ int main(int argc, char** argv) { bool kflag = false; bool bflag = false; - int64_t m = 1; - int64_t n = 1; - int64_t k = 1; - int64_t bl = 32; + size_t m = 1; + size_t n = 1; + size_t k = 1; + size_t bl = 32; int opt; while ((opt = getopt(argc, argv, "m:n:k:b:")) != -1) { @@ -84,9 +84,7 @@ int main(int argc, char** argv) { } std::cerr << oss.str(); - for (const auto& benchmark : kai::benchmark::matmul_benchmarks) { - benchmark->Args({m, n, k, bl})->ArgNames({"m", "n", "k", "bl"}); - } + kai::benchmark::RegisterMatMulBenchmarks({m, n, k}, bl); ::benchmark::RunSpecifiedBenchmarks(); ::benchmark::Shutdown(); diff --git a/benchmark/matmul/matmul_registry.cpp b/benchmark/matmul/matmul_registry.cpp index 286a7c84..aae944f4 100644 --- a/benchmark/matmul/matmul_registry.cpp +++ b/benchmark/matmul/matmul_registry.cpp @@ -7,6 +7,8 @@ #include "matmul_registry.hpp" #include +#include <cstddef> +#include <cstdint> #include #include @@ -88,7 +90,6 @@ #include
"kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" namespace kai::benchmark { - using DataType = test::DataType; // matmul_clamp_f16_bf16p_bf16p @@ -300,7 +301,7 @@ inline constexpr MatMulStaticQuantInterface kai_matmul_clamp_qai8_qai8p2vlx4_qsi .run_matmul = kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, }; -const std::array matmul_benchmarks{ +inline const std::array matmul_benchmarks{ // matmul_clamp_f16_bf16p_bf16p RegisterBenchmark( "kai_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla", kai_benchmark_matmul, @@ -526,4 +527,13 @@ const std::array matmul_benchmarks{ test::cpu_has_sme2), }; +void RegisterMatMulBenchmarks(const MatMulShape& shape, const size_t bl) { + for (const auto& benchmark : matmul_benchmarks) { + benchmark + ->Args( + {static_cast(shape.m), static_cast(shape.n), static_cast(shape.k), + static_cast(bl)}) + ->ArgNames({"m", "n", "k", "bl"}); + } +} } // namespace kai::benchmark diff --git a/benchmark/matmul/matmul_registry.hpp b/benchmark/matmul/matmul_registry.hpp index 644be5c2..82ebf11d 100644 --- a/benchmark/matmul/matmul_registry.hpp +++ b/benchmark/matmul/matmul_registry.hpp @@ -7,14 +7,18 @@ #ifndef KLEIDIAI_BENCHMARK_MATMUL_MATMUL_REGISTRY_HPP #define KLEIDIAI_BENCHMARK_MATMUL_MATMUL_REGISTRY_HPP -#include +#include -#include "benchmark/benchmark.h" +#include "test/common/matmul_test_common.hpp" namespace kai::benchmark { +using test::MatMulShape; -/// Array of registered matrix multiplication benchmarks -extern const std::array<::benchmark::internal::Benchmark*, 45> matmul_benchmarks; +/// Registers matrix multiplication micro-kernels for benchmarking. +/// +/// @param shape Shape with M, N and K dimensions describing the matrix multiplication problem. +/// @param bl Block size. Used for micro-kernels with dynamic blockwise quantization. 
+void RegisterMatMulBenchmarks(const MatMulShape& shape, size_t bl); } // namespace kai::benchmark -- GitLab From 9c9f60171996b796f03369f3c6cd12c934c8ca34 Mon Sep 17 00:00:00 2001 From: Jakub Sujak Date: Mon, 24 Feb 2025 13:22:24 +0000 Subject: [PATCH 3/5] Add kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm benchmark Signed-off-by: Jakub Sujak --- benchmark/matmul/matmul_registry.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/benchmark/matmul/matmul_registry.cpp b/benchmark/matmul/matmul_registry.cpp index aae944f4..5dc2fb61 100644 --- a/benchmark/matmul/matmul_registry.cpp +++ b/benchmark/matmul/matmul_registry.cpp @@ -50,6 +50,7 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h" // matmul_clamp_f32_qai8dxp_qsi4cxp @@ -183,6 +184,11 @@ inline constexpr MatMulBlockwiseDynamicQuantInterface .run_matmul = kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, }; +inline constexpr MatMulBlockwiseDynamicQuantInterface + kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm_interface{ + .run_matmul = kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm, + }; + inline constexpr MatMulBlockwiseDynamicQuantInterface kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_interface{ .run_matmul = kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, @@ -400,6 +406,11 @@ inline const std::array matmul_benchmarks{ kai_benchmark_matmul, kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_interface, DataType::FP32, MatMulOp::GEMM, test::cpu_has_i8mm), + RegisterBenchmark( + "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm", + kai_benchmark_matmul, + kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm_interface, DataType::FP32, MatMulOp::GEMM, + test::cpu_has_i8mm), RegisterBenchmark( "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm", kai_benchmark_matmul, -- GitLab From 6319efa18f9e99813e3ae76e848899ba6fcdd6cb Mon Sep 17 00:00:00 2001 From: Jakub Sujak Date: Mon, 24 Feb 2025 13:37:03 +0000 Subject: [PATCH 4/5] Remove unused header Signed-off-by: Jakub Sujak --- benchmark/main.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmark/main.cpp b/benchmark/main.cpp index 0f6a0edd..6ab9ac74 100644 --- a/benchmark/main.cpp +++ b/benchmark/main.cpp @@ -7,7 +7,6 @@ #include #include -#include #include #include #include -- GitLab From 1e7dfa33841b27d8d737ef1f2a896180932e4c1a Mon Sep 17 00:00:00 2001 From: Jakub Sujak Date: Mon, 24 Feb 2025 14:54:58 +0000 Subject: [PATCH 5/5] Ignore diagnostics in third-party dependencies Signed-off-by: Jakub Sujak --- benchmark/main.cpp | 12 +++++++++++- benchmark/matmul/matmul_benchmark_logic.hpp | 12 +++++++++++- benchmark/matmul/matmul_registry.cpp | 12 +++++++++++- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/benchmark/main.cpp b/benchmark/main.cpp index 6ab9ac74..5dac84b4 100644 --- a/benchmark/main.cpp +++ b/benchmark/main.cpp @@ -4,7 +4,6 @@ //
SPDX-License-Identifier: Apache-2.0 // -#include <benchmark/benchmark.h> #include #include @@ -14,6 +13,17 @@ #include "benchmark/matmul/matmul_registry.hpp" #include "kai/kai_common.h" +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wswitch-default" +#endif // __GNUC__ + +#include <benchmark/benchmark.h> + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif // __GNUC__ + namespace { void print_usage(char* name) { std::ostringstream oss; diff --git a/benchmark/matmul/matmul_benchmark_logic.hpp b/benchmark/matmul/matmul_benchmark_logic.hpp index a6ddc99e..cff73a2d 100644 --- a/benchmark/matmul/matmul_benchmark_logic.hpp +++ b/benchmark/matmul/matmul_benchmark_logic.hpp @@ -12,10 +12,20 @@ #include #include -#include "benchmark/benchmark.h" #include "matmul_interface.hpp" #include "matmul_runner.hpp" +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wswitch-default" +#endif // __GNUC__ + +#include <benchmark/benchmark.h> + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif // __GNUC__ + namespace kai::benchmark { using Buffer = std::vector; using CpuRequirement = std::function; diff --git a/benchmark/matmul/matmul_registry.cpp b/benchmark/matmul/matmul_registry.cpp index 5dc2fb61..75d6a48b 100644 --- a/benchmark/matmul/matmul_registry.cpp +++ b/benchmark/matmul/matmul_registry.cpp @@ -12,10 +12,20 @@ #include #include -#include "benchmark/benchmark.h" #include "matmul_benchmark_logic.hpp" #include "matmul_interface.hpp" +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wswitch-default" +#endif // __GNUC__ + +#include <benchmark/benchmark.h> + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif // __GNUC__ + // Micro-kernels to register for benchmarking // matmul_clamp_f16_bf16p_bf16p -- GitLab
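
Reviewer note: the series above hinges on one pattern — each micro-kernel family is described by a small interface aggregate holding the kernel's run_matmul entry point, and MatMulRunner is templated on that aggregate, so each calling convention gets its own run() specialization selected at compile time rather than through virtual dispatch. The standalone sketch below is a simplified model of that design with stand-in names (ExampleInterface, example_kernel, ExampleRunner — none of these exist in KleidiAI); it is illustrative only, not code from the patches.

#include <cstdio>

// Stand-in for one of the MatMul*Interface aggregates: a plain struct whose
// only job is to carry the micro-kernel entry point.
struct ExampleInterface {
    void (*run_matmul)(int m, int n, int k);
};

// Stand-in micro-kernel.
void example_kernel(int m, int n, int k) {
    std::printf("matmul m=%d n=%d k=%d\n", m, n, k);
}

// Each kernel is described once as a constant, like the kai_*_interface
// objects in matmul_registry.cpp.
constexpr ExampleInterface example_interface{.run_matmul = example_kernel};

// Stand-in for MatMulRunner: templated on the interface type, so a new
// kernel signature only needs a new interface struct plus a run()
// specialization, with no virtual calls on the hot path.
template <typename Interface>
class ExampleRunner {
public:
    explicit ExampleRunner(const Interface& iface) : iface_(iface) {}
    void run(int m, int n, int k) const { iface_.run_matmul(m, n, k); }

private:
    Interface iface_;
};

int main() {
    ExampleRunner<ExampleInterface> runner(example_interface);
    runner.run(13, 17, 18);  // problem shape, as given by -m/-n/-k
    return 0;
}

In the tool itself the equivalent flow is driven from the command line, e.g. ./kleidiai_benchmark -m 13 -n 17 -k 18 -b 32, where the optional -b block size only affects the blockwise-quantized micro-kernels.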