From 2e669c976d29609843c3cd3c7063e565b99506f7 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Wed, 21 Aug 2024 10:34:59 +0200 Subject: [PATCH 01/10] Add FP32 SME2-kernel to the benchmark suite. Signed-off-by: Jens Elofsson --- benchmark/main.cpp | 9 +- benchmark/matmul/matmul_f32.hpp | 99 ++++++++++++++++++- ...kai_matmul_clamp_f32_f32p_f32p_interface.h | 51 ++++++++++ 3 files changed, 153 insertions(+), 6 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32p_interface.h diff --git a/benchmark/main.cpp b/benchmark/main.cpp index 5d8e7a32..19c7a251 100644 --- a/benchmark/main.cpp +++ b/benchmark/main.cpp @@ -59,11 +59,18 @@ int main(int argc, char** argv) { exit(EXIT_FAILURE); } - kai_matmul matmul_f32; + kai_matmul_f32_qai8_qsi4 matmul_f32; for (int i = 0; i < num_ukernel_variants; i++) { ::benchmark::RegisterBenchmark(ukernel_variants[i].name, matmul_f32, ukernel_variants[i], m, n, k); } +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE2) + kai_matmul_f32_f32p_f32p_sme sme_kernel; + for (int i = 0; i < num_sme_variants; i++) { + ::benchmark::RegisterBenchmark(sme_variants[i].name, sme_kernel, sme_variants[i], m, n, k)->Iterations(2000); + } +#endif + ::benchmark::RunSpecifiedBenchmarks(); ::benchmark::Shutdown(); return 0; diff --git a/benchmark/matmul/matmul_f32.hpp b/benchmark/matmul/matmul_f32.hpp index e8aa2dfb..76f3e458 100644 --- a/benchmark/matmul/matmul_f32.hpp +++ b/benchmark/matmul/matmul_f32.hpp @@ -8,10 +8,13 @@ #define KAI_BENCHMARK_MATMUL_MATMUL_F32_HPP #include +#include #include #include "benchmark/matmul/matmul_utils.hpp" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32p_interface.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" #include 
"kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h" @@ -19,8 +22,14 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h" +#include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" + +const size_t seed_lhs = 4568; +const size_t seed_rhs = seed_lhs + 4; +const size_t seed_bias = seed_rhs + 4; struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel ukernel; @@ -105,13 +114,34 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { // Number of micro-kernel variants stored in the array const size_t num_ukernel_variants = sizeof(ukernel_variants) / sizeof(ukernel_variants[0]); -struct kai_matmul { +struct kai_matmul_ukernel_f32_f32p_f32p { + kai_matmul_clamp_f32_f32p_f32p_ukernel ukernel; + std::string name = {}; +}; + +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE2) +kai_matmul_ukernel_f32_f32p_f32p sme_variants[] = { + {kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + 
kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + "matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa"}, +}; + +const size_t num_sme_variants = sizeof(sme_variants) / sizeof(sme_variants[0]); +#endif /* defined(__aarch64__) && defined(__ARM_FEATURE_SVE2) */ + +struct kai_matmul_f32_qai8_qsi4 { template void operator()( benchmark::State& state, kai_matmul_ukernel_f32_qa8dxp_qs4cxp variant, size_t m, size_t n, size_t k) const { - const size_t seed_lhs = 4568; - const size_t seed_rhs = seed_lhs + 4; - const size_t lhs_native_size_f32 = m * k * sizeof(float); const size_t rhs_native_size_f32 = n * k * sizeof(float); const size_t rhs_native_size_qs4cx = n * (k / 2) * sizeof(uint8_t); @@ -198,6 +228,65 @@ struct kai_matmul { delete[] rhs_native_mtx_qs4cx; delete[] rhs_scales_f32; } -}; /* struct kai_matmul */ +}; /* kai_matmul_f32_qai8_qsi4 */ + +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE2) +struct kai_matmul_f32_f32p_f32p_sme { + template + void operator()( + benchmark::State& state, kai_matmul_ukernel_f32_f32p_f32p variant, size_t m, size_t n, size_t k) const { + const size_t lhs_size = m * k; + const size_t rhs_size = n * k; + const size_t bias_size = n; + const size_t dst_size = m * n; + + float* lhs = new float[lhs_size]; + float* rhs = new float[rhs_size]; + float* bias = new float[bias_size]; + + fill_uniform_random(m, k, lhs, seed_lhs); + fill_uniform_random(k, n, rhs, seed_rhs); + fill_uniform_random(1, n, bias, seed_bias); + + const size_t nr = variant.ukernel.get_nr(); + const size_t kr = variant.ukernel.get_kr(); + const size_t sr = 
variant.ukernel.get_sr(); + + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(n, k); + const size_t rhs_packed_cols = nr + k * nr; + const size_t rhs_packed_rows = rhs_packed_size / (rhs_packed_cols * sizeof(float)); + + float* rhs_packed = new float[rhs_packed_size]; + + const size_t lhs_stride = k * sizeof(float); + const size_t rhs_stride = n * sizeof(float); + const size_t dst_stride_row = n * sizeof(float); + const size_t dst_stride_col = sizeof(float); + kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( + 1, n, k, nr, kr, sr, // Packing arguments + rhs_stride, rhs, bias, NULL, rhs_packed, 0, NULL); + + float* dst = new float[dst_size]; + for (auto _ : state) { + // run matmul + variant.ukernel.run_matmul( + m, n, k, // Dimensions + lhs, // LHS + rhs_packed, // RHS packed + dst, // DST + dst_stride_row, // DST stride (row) + dst_stride_col, // DST stride (col) + FLT_MIN, FLT_MAX // Min and max for the clamp operation + ); + } + + delete[] lhs; + delete[] rhs; + delete[] bias; + delete[] rhs_packed; + delete[] dst; + } +}; /* struct kai_matmul_f32_f32p_f32p_sme */ +#endif /* defined(__aarch64__) && defined(__ARM_FEATURE_SVE2) */ #endif /* KAI_BENCHMARK_MATMUL_MATMUL_F32_HPP */ diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32p_interface.h b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32p_interface.h new file mode 100644 index 00000000..3e40c2f4 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32p_interface.h @@ -0,0 +1,51 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// All micro-kernels variants of the same type share the same interfaces +// In this case, the micro-kernel type is: matmul_clamp_f32_f32p_f32p + +/// Micro-kernel helper 
functions ("get" methods) +typedef size_t (*kai_matmul_clamp_f32_f32p_f32p_get_m_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_f32p_f32p_get_n_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_f32p_f32p_get_mr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_f32p_f32p_get_nr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_f32p_f32p_get_kr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_f32p_f32p_get_sr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_f32p_f32p_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k); +typedef size_t (*kai_matmul_clamp_f32_f32p_f32p_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); +typedef size_t (*kai_matmul_clamp_f32_f32p_f32p_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride); +typedef size_t (*kai_matmul_clamp_f32_f32p_f32p_get_dst_size_func_t)(size_t m, size_t n); + +/// Micro-kernel core function ("run" method) +typedef void (*kai_matmul_clamp_f32_f32p_f32p_run_matmul_func_t)( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, + size_t dst_stride_col, float clamp_min, float clamp_max); + +/// Micro-kernel interface +struct kai_matmul_clamp_f32_f32p_f32p_ukernel { + kai_matmul_clamp_f32_f32p_f32p_get_m_step_func_t get_m_step; + kai_matmul_clamp_f32_f32p_f32p_get_n_step_func_t get_n_step; + kai_matmul_clamp_f32_f32p_f32p_get_mr_func_t get_mr; + kai_matmul_clamp_f32_f32p_f32p_get_nr_func_t get_nr; + kai_matmul_clamp_f32_f32p_f32p_get_nr_func_t get_kr; + kai_matmul_clamp_f32_f32p_f32p_get_sr_func_t get_sr; + kai_matmul_clamp_f32_f32p_f32p_get_lhs_packed_offset_func_t get_lhs_packed_offset; + kai_matmul_clamp_f32_f32p_f32p_get_rhs_packed_offset_func_t get_rhs_packed_offset; + kai_matmul_clamp_f32_f32p_f32p_get_dst_offset_func_t get_dst_offset; + kai_matmul_clamp_f32_f32p_f32p_get_dst_size_func_t get_dst_size; + kai_matmul_clamp_f32_f32p_f32p_run_matmul_func_t run_matmul; +}; + +#ifdef __cplusplus +} 
+#endif -- GitLab From 2fd4236bae81a829190eedf5b3b9dd7dab74ce34 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Wed, 11 Sep 2024 10:06:53 +0200 Subject: [PATCH 02/10] Split the different benchmark kernel types into seperate files. Split the different kernel types into individual source files, and add the proper compile flags for each of the files. Signed-off-by: Jens Elofsson --- CMakeLists.txt | 6 +- benchmark/main.cpp | 15 +- benchmark/matmul/matmul_f32.cpp | 208 ++++++++++++++++ benchmark/matmul/matmul_f32.hpp | 278 +--------------------- benchmark/matmul/matmul_f32_f32p_f32p.cpp | 114 +++++++++ benchmark/matmul/matmul_f32_f32p_f32p.hpp | 20 ++ 6 files changed, 354 insertions(+), 287 deletions(-) create mode 100644 benchmark/matmul/matmul_f32.cpp create mode 100644 benchmark/matmul/matmul_f32_f32p_f32p.cpp create mode 100644 benchmark/matmul/matmul_f32_f32p_f32p.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 7fc0dc2c..ad0c00f6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -225,8 +225,12 @@ if(KLEIDIAI_BUILD_BENCHMARK) include(FetchGBench) add_executable(kleidiai_benchmark - benchmark/main.cpp) + benchmark/main.cpp + benchmark/matmul/matmul_f32_f32p_f32p.cpp + benchmark/matmul/matmul_f32.cpp) + set_source_files_properties(benchmark/matmul/matmul_f32_f32p_f32p.cpp PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+sve+sve2) + set_source_files_properties(benchmark/matmul/matmul_f32.cpp PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+dotprod+i8mm) target_link_libraries( kleidiai_benchmark kleidiai diff --git a/benchmark/main.cpp b/benchmark/main.cpp index 19c7a251..ebc3bafb 100644 --- a/benchmark/main.cpp +++ b/benchmark/main.cpp @@ -4,6 +4,7 @@ // SPDX-License-Identifier: Apache-2.0 // +#include #include #include @@ -11,6 +12,7 @@ #include #include "benchmark/matmul/matmul_f32.hpp" +#include "benchmark/matmul/matmul_f32_f32p_f32p.hpp" void print_usage(char* name) { fprintf(stderr, "Usage:\n"); @@ -59,17 +61,8 @@ int main(int argc, char** argv) { 
exit(EXIT_FAILURE); } - kai_matmul_f32_qai8_qsi4 matmul_f32; - for (int i = 0; i < num_ukernel_variants; i++) { - ::benchmark::RegisterBenchmark(ukernel_variants[i].name, matmul_f32, ukernel_variants[i], m, n, k); - } - -#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE2) - kai_matmul_f32_f32p_f32p_sme sme_kernel; - for (int i = 0; i < num_sme_variants; i++) { - ::benchmark::RegisterBenchmark(sme_variants[i].name, sme_kernel, sme_variants[i], m, n, k)->Iterations(2000); - } -#endif + kai::bench::matmul_f32_qa8dxp_qs4cxp::RegisterBenchmarks(m, n, k); + kai::bench::matmul_f32_f32p_f32p::RegisterBenchmarks(m, n, k); ::benchmark::RunSpecifiedBenchmarks(); ::benchmark::Shutdown(); diff --git a/benchmark/matmul/matmul_f32.cpp b/benchmark/matmul/matmul_f32.cpp new file mode 100644 index 00000000..789e2f82 --- /dev/null +++ b/benchmark/matmul/matmul_f32.cpp @@ -0,0 +1,208 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "benchmark/matmul/matmul_utils.hpp" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h" +#include 
"kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" + +namespace kai::bench::matmul_f32_qa8dxp_qs4cxp { + +const size_t seed_lhs = 4568; +const size_t seed_rhs = seed_lhs + 4; +const size_t seed_bias = seed_rhs + 4; + +struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { + kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel ukernel; + std::string name = {}; +}; + +kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + 
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + 
kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm"}, +}; + +// Number of micro-kernel 
variants stored in the array +const size_t num_ukernel_variants = sizeof(ukernel_variants) / sizeof(ukernel_variants[0]); + +struct kai_matmul_f32_qai8_qsi4 { + template + void operator()( + benchmark::State& state, kai_matmul_ukernel_f32_qa8dxp_qs4cxp variant, size_t m, size_t n, size_t k) const { + const size_t lhs_native_size_f32 = m * k * sizeof(float); + const size_t rhs_native_size_f32 = n * k * sizeof(float); + const size_t rhs_native_size_qs4cx = n * (k / 2) * sizeof(uint8_t); + const size_t rhs_scales_size_f32 = n * sizeof(float); + + // Allocate the memory + uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32]; + uint8_t* rhs_native_mtx_f32 = new uint8_t[rhs_native_size_f32]; + uint8_t* rhs_native_mtx_qs4cx = new uint8_t[rhs_native_size_qs4cx]; + uint8_t* rhs_scales_f32 = new uint8_t[rhs_scales_size_f32]; + + fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs); + fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs); + + quant_qs4cx_f32(n, k, (const float*)rhs_native_mtx_f32, rhs_native_mtx_qs4cx, (float*)rhs_scales_f32); + + delete[] rhs_native_mtx_f32; + + // Get the packing parameters + const size_t mr = variant.ukernel.get_mr(); + const size_t nr = variant.ukernel.get_nr(); + const size_t kr = variant.ukernel.get_kr(); + const size_t sr = variant.ukernel.get_sr(); + + // Get the size in bytes for the packed matrices + const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr); + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(n, k, nr, kr, sr); + const size_t dst_size = variant.ukernel.get_dst_size(m, n); + + // Allocate the matrices + uint8_t* lhs_packed_mtx_qa8dx = new uint8_t[lhs_packed_size]; + uint8_t* rhs_packed_mtx_qs4cx = new uint8_t[rhs_packed_size]; + uint8_t* dst_act_mtx_f32 = new uint8_t[dst_size]; + + // If the RHS matrix contains constant values, the packing can be performed + // only once + struct 
kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + + // RHS packing + kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( + 1, n, k, nr, kr, sr, // Packing arguments + (const uint8_t*)(rhs_native_mtx_qs4cx), // RHS + NULL, // Bias + (const float*)(rhs_scales_f32), // Scale + rhs_packed_mtx_qs4cx, // RHS packed + 0, ¶ms); + + // LHS packing + kai_run_lhs_quant_pack_qai8dxp_f32( + m, k, mr, kr, sr, 0, // Packing arguments + (const float*)lhs_native_mtx_f32, // LHS + k * sizeof(float), // LHS stride + lhs_packed_mtx_qa8dx); // LHS packed + + const size_t dst_stride = n * sizeof(float); + const size_t lhs_offset = variant.ukernel.get_lhs_packed_offset(0, k); + const size_t rhs_offset = variant.ukernel.get_rhs_packed_offset(0, k); + const size_t dst_offset = variant.ukernel.get_dst_offset(0, 0, dst_stride); + + const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx + lhs_offset); + const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4cx + rhs_offset); + float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + dst_offset); + + // Matmul + for (auto _ : state) { + variant.ukernel.run_matmul( + m, n, k, // Dimensions + lhs_ptr, // LHS packed + rhs_ptr, // RHS packed + dst_ptr, // DST + dst_stride, // DST stride (row) + sizeof(float), // DST stride (col) + -FLT_MAX, FLT_MAX // Min and max for the clamp operation + ); + } + + delete[] lhs_packed_mtx_qa8dx; + delete[] rhs_packed_mtx_qs4cx; + delete[] dst_act_mtx_f32; + + delete[] lhs_native_mtx_f32; + delete[] rhs_native_mtx_qs4cx; + delete[] rhs_scales_f32; + } +}; /* kai_matmul_f32_qai8_qsi4 */ + +void RegisterBenchmarks(size_t m, size_t n, size_t k) { + kai_matmul_f32_qai8_qsi4 matmul_f32; + for (int i = 0; i < num_ukernel_variants; i++) { + ::benchmark::RegisterBenchmark(ukernel_variants[i].name, matmul_f32, ukernel_variants[i], m, n, k); + } +} + +}; // namespace kai::bench::matmul_f32_qa8dxp_qs4cxp diff --git a/benchmark/matmul/matmul_f32.hpp 
b/benchmark/matmul/matmul_f32.hpp index 76f3e458..6bc11eae 100644 --- a/benchmark/matmul/matmul_f32.hpp +++ b/benchmark/matmul/matmul_f32.hpp @@ -7,286 +7,14 @@ #ifndef KAI_BENCHMARK_MATMUL_MATMUL_F32_HPP #define KAI_BENCHMARK_MATMUL_MATMUL_F32_HPP -#include #include #include -#include "benchmark/matmul/matmul_utils.hpp" -#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32p_interface.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h" -#include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" -#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" -#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" -#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" +namespace kai::bench::matmul_f32_qa8dxp_qs4cxp { -const size_t seed_lhs = 4568; -const size_t seed_rhs = seed_lhs + 4; -const size_t seed_bias = seed_rhs + 4; +void RegisterBenchmarks(size_t m, size_t n, size_t k); -struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { - 
kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel ukernel; - std::string name = {}; -}; - -kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { - {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - "matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - "matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod"}, - 
{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - 
kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm"}, -}; - -// Number of micro-kernel variants stored in the array -const size_t num_ukernel_variants = sizeof(ukernel_variants) / sizeof(ukernel_variants[0]); - -struct kai_matmul_ukernel_f32_f32p_f32p { - kai_matmul_clamp_f32_f32p_f32p_ukernel ukernel; - std::string name = {}; -}; - -#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE2) -kai_matmul_ukernel_f32_f32p_f32p sme_variants[] = { - 
{kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - "matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa"}, -}; - -const size_t num_sme_variants = sizeof(sme_variants) / sizeof(sme_variants[0]); -#endif /* defined(__aarch64__) && defined(__ARM_FEATURE_SVE2) */ - -struct kai_matmul_f32_qai8_qsi4 { - template - void operator()( - benchmark::State& state, kai_matmul_ukernel_f32_qa8dxp_qs4cxp variant, size_t m, size_t n, size_t k) const { - const size_t lhs_native_size_f32 = m * k * sizeof(float); - const size_t rhs_native_size_f32 = n * k * sizeof(float); - const size_t rhs_native_size_qs4cx = n * (k / 2) * sizeof(uint8_t); - const size_t rhs_scales_size_f32 = n * sizeof(float); - - // Allocate the memory - uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32]; - uint8_t* rhs_native_mtx_f32 = new uint8_t[rhs_native_size_f32]; - uint8_t* rhs_native_mtx_qs4cx = new uint8_t[rhs_native_size_qs4cx]; - uint8_t* rhs_scales_f32 = new uint8_t[rhs_scales_size_f32]; - - fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs); - fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs); - - quant_qs4cx_f32(n, k, (const float*)rhs_native_mtx_f32, rhs_native_mtx_qs4cx, (float*)rhs_scales_f32); - - delete[] rhs_native_mtx_f32; 
- - // Get the packing parameters - const size_t mr = variant.ukernel.get_mr(); - const size_t nr = variant.ukernel.get_nr(); - const size_t kr = variant.ukernel.get_kr(); - const size_t sr = variant.ukernel.get_sr(); - - // Get the size in bytes for the packed matrices - const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr); - const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(n, k, nr, kr, sr); - const size_t dst_size = variant.ukernel.get_dst_size(m, n); - - // Allocate the matrices - uint8_t* lhs_packed_mtx_qa8dx = new uint8_t[lhs_packed_size]; - uint8_t* rhs_packed_mtx_qs4cx = new uint8_t[rhs_packed_size]; - uint8_t* dst_act_mtx_f32 = new uint8_t[dst_size]; - - // If the RHS matrix contains constant values, the packing can be performed - // only once - struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params; - params.lhs_zero_point = 1; - params.rhs_zero_point = 8; - - // RHS packing - kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( - 1, n, k, nr, kr, sr, // Packing arguments - (const uint8_t*)(rhs_native_mtx_qs4cx), // RHS - NULL, // Bias - (const float*)(rhs_scales_f32), // Scale - rhs_packed_mtx_qs4cx, // RHS packed - 0, ¶ms); - - // LHS packing - kai_run_lhs_quant_pack_qai8dxp_f32( - m, k, mr, kr, sr, 0, // Packing arguments - (const float*)lhs_native_mtx_f32, // LHS - k * sizeof(float), // LHS stride - lhs_packed_mtx_qa8dx); // LHS packed - - const size_t dst_stride = n * sizeof(float); - const size_t lhs_offset = variant.ukernel.get_lhs_packed_offset(0, k); - const size_t rhs_offset = variant.ukernel.get_rhs_packed_offset(0, k); - const size_t dst_offset = variant.ukernel.get_dst_offset(0, 0, dst_stride); - - const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx + lhs_offset); - const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4cx + rhs_offset); - float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + dst_offset); - - // Matmul - for (auto _ : state) 
{ - variant.ukernel.run_matmul( - m, n, k, // Dimensions - lhs_ptr, // LHS packed - rhs_ptr, // RHS packed - dst_ptr, // DST - dst_stride, // DST stride (row) - sizeof(float), // DST stride (col) - -FLT_MAX, FLT_MAX // Min and max for the clamp operation - ); - } - - delete[] lhs_packed_mtx_qa8dx; - delete[] rhs_packed_mtx_qs4cx; - delete[] dst_act_mtx_f32; - - delete[] lhs_native_mtx_f32; - delete[] rhs_native_mtx_qs4cx; - delete[] rhs_scales_f32; - } -}; /* kai_matmul_f32_qai8_qsi4 */ - -#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE2) -struct kai_matmul_f32_f32p_f32p_sme { - template - void operator()( - benchmark::State& state, kai_matmul_ukernel_f32_f32p_f32p variant, size_t m, size_t n, size_t k) const { - const size_t lhs_size = m * k; - const size_t rhs_size = n * k; - const size_t bias_size = n; - const size_t dst_size = m * n; - - float* lhs = new float[lhs_size]; - float* rhs = new float[rhs_size]; - float* bias = new float[bias_size]; - - fill_uniform_random(m, k, lhs, seed_lhs); - fill_uniform_random(k, n, rhs, seed_rhs); - fill_uniform_random(1, n, bias, seed_bias); - - const size_t nr = variant.ukernel.get_nr(); - const size_t kr = variant.ukernel.get_kr(); - const size_t sr = variant.ukernel.get_sr(); - - const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(n, k); - const size_t rhs_packed_cols = nr + k * nr; - const size_t rhs_packed_rows = rhs_packed_size / (rhs_packed_cols * sizeof(float)); - - float* rhs_packed = new float[rhs_packed_size]; - - const size_t lhs_stride = k * sizeof(float); - const size_t rhs_stride = n * sizeof(float); - const size_t dst_stride_row = n * sizeof(float); - const size_t dst_stride_col = sizeof(float); - kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( - 1, n, k, nr, kr, sr, // Packing arguments - rhs_stride, rhs, bias, NULL, rhs_packed, 0, NULL); - - float* dst = new float[dst_size]; - for (auto _ : state) { - // run matmul - variant.ukernel.run_matmul( - m, n, k, 
// Dimensions - lhs, // LHS - rhs_packed, // RHS packed - dst, // DST - dst_stride_row, // DST stride (row) - dst_stride_col, // DST stride (col) - FLT_MIN, FLT_MAX // Min and max for the clamp operation - ); - } - - delete[] lhs; - delete[] rhs; - delete[] bias; - delete[] rhs_packed; - delete[] dst; - } -}; /* struct kai_matmul_f32_f32p_f32p_sme */ -#endif /* defined(__aarch64__) && defined(__ARM_FEATURE_SVE2) */ +}; // namespace kai::bench::matmul_f32_qa8dxp_qs4cxp #endif /* KAI_BENCHMARK_MATMUL_MATMUL_F32_HPP */ diff --git a/benchmark/matmul/matmul_f32_f32p_f32p.cpp b/benchmark/matmul/matmul_f32_f32p_f32p.cpp new file mode 100644 index 00000000..b4a6d68b --- /dev/null +++ b/benchmark/matmul/matmul_f32_f32p_f32p.cpp @@ -0,0 +1,114 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2 +#else // Architectural feature check + +#include "benchmark/matmul/matmul_f32_f32p_f32p.hpp" + +#include + +#include "benchmark/matmul/matmul_utils.hpp" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32p_interface.h" +#include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" + +namespace kai::bench::matmul_f32_f32p_f32p { + +const size_t seed_lhs = 4568; +const size_t seed_rhs = seed_lhs + 4; +const size_t seed_bias = seed_rhs + 4; + +struct kai_matmul_ukernel_f32_f32p_f32p { + kai_matmul_clamp_f32_f32p_f32p_ukernel ukernel; + std::string name = {}; +}; + +kai_matmul_ukernel_f32_f32p_f32p sme_variants[] = { + {kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + 
kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + "matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa"}, +}; + +const size_t num_sme_variants = sizeof(sme_variants) / sizeof(sme_variants[0]); + +struct kai_matmul_f32_f32p_f32p_sme { + template + void operator()( + benchmark::State& state, kai_matmul_ukernel_f32_f32p_f32p variant, size_t m, size_t n, size_t k) const { + const size_t lhs_size = m * k; + const size_t rhs_size = n * k; + const size_t bias_size = n; + const size_t dst_size = m * n; + + float* lhs = new float[lhs_size]; + float* rhs = new float[rhs_size]; + float* bias = new float[bias_size]; + + fill_uniform_random(m, k, lhs, seed_lhs); + fill_uniform_random(k, n, rhs, seed_rhs); + fill_uniform_random(1, n, bias, seed_bias); + + const size_t nr = variant.ukernel.get_nr(); + const size_t kr = variant.ukernel.get_kr(); + const size_t sr = variant.ukernel.get_sr(); + + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(n, k); + const size_t rhs_packed_cols = nr + k * nr; + const size_t rhs_packed_rows = rhs_packed_size / (rhs_packed_cols * sizeof(float)); + + float* rhs_packed = new float[rhs_packed_size]; + + const size_t lhs_stride = k * sizeof(float); + const size_t rhs_stride = n * sizeof(float); + const size_t dst_stride_row = n * 
sizeof(float); + const size_t dst_stride_col = sizeof(float); + kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( + 1, n, k, nr, kr, sr, // Packing arguments + rhs_stride, rhs, bias, NULL, rhs_packed, 0, NULL); + + float* dst = new float[dst_size]; + for (auto _ : state) { + // run matmul + variant.ukernel.run_matmul( + m, n, k, // Dimensions + lhs, // LHS + rhs_packed, // RHS packed + dst, // DST + dst_stride_row, // DST stride (row) + dst_stride_col, // DST stride (col) + -FLT_MAX, FLT_MAX // Min and max for the clamp operation + ); + } + + delete[] lhs; + delete[] rhs; + delete[] bias; + delete[] rhs_packed; + delete[] dst; + } +}; /* struct kai_matmul_f32_f32p_f32p_sme */ + +void RegisterBenchmarks(size_t m, size_t n, size_t k) { + kai_matmul_f32_f32p_f32p_sme sme_kernel; + for (int i = 0; i < num_sme_variants; i++) { + ::benchmark::RegisterBenchmark(sme_variants[i].name, sme_kernel, sme_variants[i], m, n, k)->Iterations(2000); + } +} +} // namespace kai::bench::matmul_f32_f32p_f32p + +#endif /* defined(__aarch64__) && defined(__ARM_FEATURE_SVE2) */ diff --git a/benchmark/matmul/matmul_f32_f32p_f32p.hpp b/benchmark/matmul/matmul_f32_f32p_f32p.hpp new file mode 100644 index 00000000..802835a7 --- /dev/null +++ b/benchmark/matmul/matmul_f32_f32p_f32p.hpp @@ -0,0 +1,20 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#ifndef MATMUL_F32_F32P_F32P +#define MATMUL_F32_F32P_F32P + +#include + +#include + +namespace kai::bench::matmul_f32_f32p_f32p { + +void RegisterBenchmarks(size_t m, size_t n, size_t k); + +}; + +#endif /* MATMUL_F32_F32P_F32P */ -- GitLab From 9e2b7ab11432d3758cae3d3b609581ec5c2f066d Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Wed, 18 Sep 2024 15:40:51 +0200 Subject: [PATCH 03/10] Check if the target has SVE2 before registering SME-benchmarks. 
Signed-off-by: Jens Elofsson --- CMakeLists.txt | 3 ++- benchmark/main.cpp | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ad0c00f6..5079bae2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -227,7 +227,8 @@ if(KLEIDIAI_BUILD_BENCHMARK) add_executable(kleidiai_benchmark benchmark/main.cpp benchmark/matmul/matmul_f32_f32p_f32p.cpp - benchmark/matmul/matmul_f32.cpp) + benchmark/matmul/matmul_f32.cpp + test/common/cpu_info.cpp) set_source_files_properties(benchmark/matmul/matmul_f32_f32p_f32p.cpp PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+sve+sve2) set_source_files_properties(benchmark/matmul/matmul_f32.cpp PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+dotprod+i8mm) diff --git a/benchmark/main.cpp b/benchmark/main.cpp index ebc3bafb..75698532 100644 --- a/benchmark/main.cpp +++ b/benchmark/main.cpp @@ -13,6 +13,7 @@ #include "benchmark/matmul/matmul_f32.hpp" #include "benchmark/matmul/matmul_f32_f32p_f32p.hpp" +#include "test/common/cpu_info.hpp" void print_usage(char* name) { fprintf(stderr, "Usage:\n"); @@ -62,7 +63,9 @@ int main(int argc, char** argv) { } kai::bench::matmul_f32_qa8dxp_qs4cxp::RegisterBenchmarks(m, n, k); - kai::bench::matmul_f32_f32p_f32p::RegisterBenchmarks(m, n, k); + if (kai::test::cpu_has_sme2()) { + kai::bench::matmul_f32_f32p_f32p::RegisterBenchmarks(m, n, k); + } ::benchmark::RunSpecifiedBenchmarks(); ::benchmark::Shutdown(); -- GitLab From f2527fbc3c1a71de7d3345ab7d2b92f4f8b02c84 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Mon, 23 Sep 2024 10:45:35 +0200 Subject: [PATCH 04/10] Address review comments Signed-off-by: Jens Elofsson --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5079bae2..3b6ea9a5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -230,8 +230,8 @@ if(KLEIDIAI_BUILD_BENCHMARK) benchmark/matmul/matmul_f32.cpp test/common/cpu_info.cpp) - 
set_source_files_properties(benchmark/matmul/matmul_f32_f32p_f32p.cpp PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+sve+sve2) - set_source_files_properties(benchmark/matmul/matmul_f32.cpp PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+dotprod+i8mm) + set_source_files_properties(benchmark/matmul/matmul_f32_f32p_f32p.cpp PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+sve+sve2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(benchmark/matmul/matmul_f32.cpp PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+dotprod+i8mm${KLEIDIAI_INTERNAL_EXTRA_ARCH}) target_link_libraries( kleidiai_benchmark kleidiai -- GitLab From 577f038d6406a2a0396638987199ecdc50bbd224 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Tue, 24 Sep 2024 10:21:15 +0200 Subject: [PATCH 05/10] Pack the lhs matrix. Signed-off-by: Jens Elofsson --- benchmark/matmul/matmul_f32_f32p_f32p.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/benchmark/matmul/matmul_f32_f32p_f32p.cpp b/benchmark/matmul/matmul_f32_f32p_f32p.cpp index b4a6d68b..49425842 100644 --- a/benchmark/matmul/matmul_f32_f32p_f32p.cpp +++ b/benchmark/matmul/matmul_f32_f32p_f32p.cpp @@ -63,10 +63,14 @@ struct kai_matmul_f32_f32p_f32p_sme { fill_uniform_random(k, n, rhs, seed_rhs); fill_uniform_random(1, n, bias, seed_bias); + const size_t mr = variant.ukernel.get_mr(); const size_t nr = variant.ukernel.get_nr(); const size_t kr = variant.ukernel.get_kr(); const size_t sr = variant.ukernel.get_sr(); + const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(m, k, mr, kr, sr); + float* lhs_packed = new float[lhs_packed_size]; + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(n, k); const size_t rhs_packed_cols = nr + k * nr; const size_t rhs_packed_rows = rhs_packed_size / (rhs_packed_cols * sizeof(float)); @@ -81,12 +85,14 @@ struct kai_matmul_f32_f32p_f32p_sme { 1, n, k, nr, kr, sr, // Packing arguments rhs_stride, rhs, bias, NULL, rhs_packed, 
0, NULL); + kai_run_lhs_pack_f32p2vlx1_f32_sme(m, k, mr, kr, sr, 0, lhs, k * sizeof(float), lhs_packed); + float* dst = new float[dst_size]; for (auto _ : state) { // run matmul variant.ukernel.run_matmul( m, n, k, // Dimensions - lhs, // LHS + lhs_packed, // LHS rhs_packed, // RHS packed dst, // DST dst_stride_row, // DST stride (row) @@ -99,6 +105,7 @@ struct kai_matmul_f32_f32p_f32p_sme { delete[] rhs; delete[] bias; delete[] rhs_packed; + delete[] lhs_packed; delete[] dst; } }; /* struct kai_matmul_f32_f32p_f32p_sme */ -- GitLab From 0cce7b837fda02b1efc56d31b1101d051f9307eb Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Tue, 24 Sep 2024 10:28:05 +0200 Subject: [PATCH 06/10] Remove unused variables Signed-off-by: Jens Elofsson --- benchmark/matmul/matmul_f32_f32p_f32p.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/benchmark/matmul/matmul_f32_f32p_f32p.cpp b/benchmark/matmul/matmul_f32_f32p_f32p.cpp index 49425842..129ce9d3 100644 --- a/benchmark/matmul/matmul_f32_f32p_f32p.cpp +++ b/benchmark/matmul/matmul_f32_f32p_f32p.cpp @@ -72,8 +72,6 @@ struct kai_matmul_f32_f32p_f32p_sme { float* lhs_packed = new float[lhs_packed_size]; const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(n, k); - const size_t rhs_packed_cols = nr + k * nr; - const size_t rhs_packed_rows = rhs_packed_size / (rhs_packed_cols * sizeof(float)); float* rhs_packed = new float[rhs_packed_size]; -- GitLab From b450a8e3f293997bc943903c7fe997ab1f2907c8 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Mon, 30 Sep 2024 16:42:27 +0200 Subject: [PATCH 07/10] Address review comments Signed-off-by: Jens Elofsson --- benchmark/main.cpp | 8 +- benchmark/matmul/matmul_f32.cpp | 194 +++++++++++++--------- benchmark/matmul/matmul_f32.hpp | 10 +- benchmark/matmul/matmul_f32_f32p_f32p.cpp | 8 +- benchmark/matmul/matmul_f32_f32p_f32p.hpp | 4 +- 5 files changed, 130 insertions(+), 94 deletions(-) diff --git a/benchmark/main.cpp b/benchmark/main.cpp index 
75698532..35021804 100644 --- a/benchmark/main.cpp +++ b/benchmark/main.cpp @@ -13,7 +13,6 @@ #include "benchmark/matmul/matmul_f32.hpp" #include "benchmark/matmul/matmul_f32_f32p_f32p.hpp" -#include "test/common/cpu_info.hpp" void print_usage(char* name) { fprintf(stderr, "Usage:\n"); @@ -62,10 +61,9 @@ int main(int argc, char** argv) { exit(EXIT_FAILURE); } - kai::bench::matmul_f32_qa8dxp_qs4cxp::RegisterBenchmarks(m, n, k); - if (kai::test::cpu_has_sme2()) { - kai::bench::matmul_f32_f32p_f32p::RegisterBenchmarks(m, n, k); - } + kai::bench::matmul_f32_qa8dxp_qs4cxp::dotprod::RegisterBenchmarks(m, n, k); + kai::bench::matmul_f32_qa8dxp_qs4cxp::i8mm::RegisterBenchmarks(m, n, k); + kai::bench::matmul_f32_f32p_f32p::RegisterBenchmarks(m, n, k); ::benchmark::RunSpecifiedBenchmarks(); ::benchmark::Shutdown(); diff --git a/benchmark/matmul/matmul_f32.cpp b/benchmark/matmul/matmul_f32.cpp index 789e2f82..74daa5d0 100644 --- a/benchmark/matmul/matmul_f32.cpp +++ b/benchmark/matmul/matmul_f32.cpp @@ -4,8 +4,15 @@ // SPDX-License-Identifier: Apache-2.0 // +#if !defined(__ARM_FEATURE_DOTPROD) && !defined(__ARM_FEATURE_MATMUL_INT8) +#error "Dotprod and I8mm extensions required to compile this example" +#else + #include +#include +#include + #include "benchmark/matmul/matmul_utils.hpp" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h" @@ -16,6 +23,7 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" +#include "test/common/cpu_info.hpp" namespace kai::bench::matmul_f32_qa8dxp_qs4cxp { @@ -28,84 +36,6 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { std::string name = {}; 
}; -kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { - {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - "matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - "matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - 
kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - 
kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm"}, -}; - -// Number of micro-kernel variants stored in the array -const size_t num_ukernel_variants = sizeof(ukernel_variants) / sizeof(ukernel_variants[0]); - struct kai_matmul_f32_qai8_qsi4 { template void operator()( @@ -198,11 +128,113 @@ struct kai_matmul_f32_qai8_qsi4 { } }; /* kai_matmul_f32_qai8_qsi4 */ +namespace dotprod { +kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + 
kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod"}, +}; + +const size_t num_ukernel_variants = sizeof(ukernel_variants) / sizeof(ukernel_variants[0]); + +void RegisterBenchmarks(size_t m, size_t n, size_t k) { + kai::bench::matmul_f32_qa8dxp_qs4cxp::kai_matmul_f32_qai8_qsi4 matmul_f32; + if 
(kai::test::cpu_has_dotprod()) { + for (int i = 0; i < dotprod::num_ukernel_variants; i++) { + ::benchmark::RegisterBenchmark( + dotprod::ukernel_variants[i].name, matmul_f32, dotprod::ukernel_variants[i], m, n, k); + } + } +} + +} /* namespace dotprod */ + +namespace i8mm { +kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + 
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm"}, +}; + +const size_t num_ukernel_i8mm_variants = sizeof(ukernel_variants) / sizeof(ukernel_variants[0]); + 
void RegisterBenchmarks(size_t m, size_t n, size_t k) { - kai_matmul_f32_qai8_qsi4 matmul_f32; - for (int i = 0; i < num_ukernel_variants; i++) { - ::benchmark::RegisterBenchmark(ukernel_variants[i].name, matmul_f32, ukernel_variants[i], m, n, k); + kai::bench::matmul_f32_qa8dxp_qs4cxp::kai_matmul_f32_qai8_qsi4 matmul_f32; + if (kai::test::cpu_has_i8mm()) { + for (int i = 0; i < i8mm::num_ukernel_i8mm_variants; i++) { + ::benchmark::RegisterBenchmark( + i8mm::ukernel_variants[i].name, matmul_f32, i8mm::ukernel_variants[i], m, n, k); + } } } +}; /* namespace i8mm */ }; // namespace kai::bench::matmul_f32_qa8dxp_qs4cxp + +#endif /* !defined(__ARM_FEATURE_DOTPROD) && !defined(__ARM_FEATURE_MATMUL_INT8) */ diff --git a/benchmark/matmul/matmul_f32.hpp b/benchmark/matmul/matmul_f32.hpp index 6bc11eae..06743a94 100644 --- a/benchmark/matmul/matmul_f32.hpp +++ b/benchmark/matmul/matmul_f32.hpp @@ -7,13 +7,17 @@ #ifndef KAI_BENCHMARK_MATMUL_MATMUL_F32_HPP #define KAI_BENCHMARK_MATMUL_MATMUL_F32_HPP -#include - -#include +#include namespace kai::bench::matmul_f32_qa8dxp_qs4cxp { +namespace dotprod { +void RegisterBenchmarks(size_t m, size_t n, size_t k); +}; /* namespace dotprod */ + +namespace i8mm { void RegisterBenchmarks(size_t m, size_t n, size_t k); +}; /* namespace i8mm */ }; // namespace kai::bench::matmul_f32_qa8dxp_qs4cxp diff --git a/benchmark/matmul/matmul_f32_f32p_f32p.cpp b/benchmark/matmul/matmul_f32_f32p_f32p.cpp index 129ce9d3..02c4c695 100644 --- a/benchmark/matmul/matmul_f32_f32p_f32p.cpp +++ b/benchmark/matmul/matmul_f32_f32p_f32p.cpp @@ -17,6 +17,7 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32p_interface.h" #include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" +#include "test/common/cpu_info.hpp" namespace kai::bench::matmul_f32_f32p_f32p { @@ -110,8 +111,11 @@ struct kai_matmul_f32_f32p_f32p_sme { void 
RegisterBenchmarks(size_t m, size_t n, size_t k) { kai_matmul_f32_f32p_f32p_sme sme_kernel; - for (int i = 0; i < num_sme_variants; i++) { - ::benchmark::RegisterBenchmark(sme_variants[i].name, sme_kernel, sme_variants[i], m, n, k)->Iterations(2000); + if (kai::test::cpu_has_sme2()) { + for (int i = 0; i < num_sme_variants; i++) { + ::benchmark::RegisterBenchmark(sme_variants[i].name, sme_kernel, sme_variants[i], m, n, k) + ->Iterations(2000); + } } } } // namespace kai::bench::matmul_f32_f32p_f32p diff --git a/benchmark/matmul/matmul_f32_f32p_f32p.hpp b/benchmark/matmul/matmul_f32_f32p_f32p.hpp index 802835a7..56e130ef 100644 --- a/benchmark/matmul/matmul_f32_f32p_f32p.hpp +++ b/benchmark/matmul/matmul_f32_f32p_f32p.hpp @@ -7,9 +7,7 @@ #ifndef MATMUL_F32_F32P_F32P #define MATMUL_F32_F32P_F32P -#include - -#include +#include namespace kai::bench::matmul_f32_f32p_f32p { -- GitLab From 1be63d9d48415c7387173303d90e6c525b6657cf Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Thu, 3 Oct 2024 10:45:00 +0200 Subject: [PATCH 08/10] Address review comments Signed-off-by: Jens Elofsson --- benchmark/matmul/matmul_f32.cpp | 14 ++++---------- benchmark/matmul/matmul_f32.hpp | 5 +---- benchmark/matmul/matmul_f32_f32p_f32p.cpp | 7 ++----- benchmark/matmul/matmul_f32_f32p_f32p.hpp | 5 +---- benchmark/matmul/matmul_utils.hpp | 5 +---- 5 files changed, 9 insertions(+), 27 deletions(-) diff --git a/benchmark/matmul/matmul_f32.cpp b/benchmark/matmul/matmul_f32.cpp index 74daa5d0..e956a7c5 100644 --- a/benchmark/matmul/matmul_f32.cpp +++ b/benchmark/matmul/matmul_f32.cpp @@ -156,14 +156,11 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { "matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod"}, }; -const size_t num_ukernel_variants = sizeof(ukernel_variants) / sizeof(ukernel_variants[0]); - void RegisterBenchmarks(size_t m, size_t n, size_t k) { kai::bench::matmul_f32_qa8dxp_qs4cxp::kai_matmul_f32_qai8_qsi4 matmul_f32; if (kai::test::cpu_has_dotprod()) { - for 
(int i = 0; i < dotprod::num_ukernel_variants; i++) { - ::benchmark::RegisterBenchmark( - dotprod::ukernel_variants[i].name, matmul_f32, dotprod::ukernel_variants[i], m, n, k); + for (const auto& variant : dotprod::ukernel_variants) { + ::benchmark::RegisterBenchmark(variant.name, matmul_f32, variant, m, n, k); } } } @@ -222,14 +219,11 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm"}, }; -const size_t num_ukernel_i8mm_variants = sizeof(ukernel_variants) / sizeof(ukernel_variants[0]); - void RegisterBenchmarks(size_t m, size_t n, size_t k) { kai::bench::matmul_f32_qa8dxp_qs4cxp::kai_matmul_f32_qai8_qsi4 matmul_f32; if (kai::test::cpu_has_i8mm()) { - for (int i = 0; i < i8mm::num_ukernel_i8mm_variants; i++) { - ::benchmark::RegisterBenchmark( - i8mm::ukernel_variants[i].name, matmul_f32, i8mm::ukernel_variants[i], m, n, k); + for (const auto& variant : i8mm::ukernel_variants) { + ::benchmark::RegisterBenchmark(variant.name, matmul_f32, variant, m, n, k); } } } diff --git a/benchmark/matmul/matmul_f32.hpp b/benchmark/matmul/matmul_f32.hpp index 06743a94..1245ca8a 100644 --- a/benchmark/matmul/matmul_f32.hpp +++ b/benchmark/matmul/matmul_f32.hpp @@ -4,8 +4,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#ifndef KAI_BENCHMARK_MATMUL_MATMUL_F32_HPP -#define KAI_BENCHMARK_MATMUL_MATMUL_F32_HPP +#pragma once #include @@ -20,5 +19,3 @@ void RegisterBenchmarks(size_t m, size_t n, size_t k); }; /* namespace i8mm */ }; // namespace kai::bench::matmul_f32_qa8dxp_qs4cxp - -#endif /* KAI_BENCHMARK_MATMUL_MATMUL_F32_HPP */ diff --git a/benchmark/matmul/matmul_f32_f32p_f32p.cpp b/benchmark/matmul/matmul_f32_f32p_f32p.cpp index 02c4c695..b2795ca0 100644 --- a/benchmark/matmul/matmul_f32_f32p_f32p.cpp +++ b/benchmark/matmul/matmul_f32_f32p_f32p.cpp @@ -45,8 +45,6 @@ kai_matmul_ukernel_f32_f32p_f32p sme_variants[] = { "matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa"}, }; -const size_t num_sme_variants = 
sizeof(sme_variants) / sizeof(sme_variants[0]); - struct kai_matmul_f32_f32p_f32p_sme { template void operator()( @@ -112,9 +110,8 @@ struct kai_matmul_f32_f32p_f32p_sme { void RegisterBenchmarks(size_t m, size_t n, size_t k) { kai_matmul_f32_f32p_f32p_sme sme_kernel; if (kai::test::cpu_has_sme2()) { - for (int i = 0; i < num_sme_variants; i++) { - ::benchmark::RegisterBenchmark(sme_variants[i].name, sme_kernel, sme_variants[i], m, n, k) - ->Iterations(2000); + for (const auto& variant : sme_variants) { + ::benchmark::RegisterBenchmark(variant.name, sme_kernel, variant, m, n, k)->Iterations(2000); } } } diff --git a/benchmark/matmul/matmul_f32_f32p_f32p.hpp b/benchmark/matmul/matmul_f32_f32p_f32p.hpp index 56e130ef..f05597ce 100644 --- a/benchmark/matmul/matmul_f32_f32p_f32p.hpp +++ b/benchmark/matmul/matmul_f32_f32p_f32p.hpp @@ -4,8 +4,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#ifndef MATMUL_F32_F32P_F32P -#define MATMUL_F32_F32P_F32P +#pragma once #include @@ -14,5 +13,3 @@ namespace kai::bench::matmul_f32_f32p_f32p { void RegisterBenchmarks(size_t m, size_t n, size_t k); }; - -#endif /* MATMUL_F32_F32P_F32P */ diff --git a/benchmark/matmul/matmul_utils.hpp b/benchmark/matmul/matmul_utils.hpp index d2c7e0a3..a19fdbd1 100644 --- a/benchmark/matmul/matmul_utils.hpp +++ b/benchmark/matmul/matmul_utils.hpp @@ -4,8 +4,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#ifndef KAI_BENCHMARK_MATMUL_MATMUL_UTILS_HPP -#define KAI_BENCHMARK_MATMUL_MATMUL_UTILS_HPP +#pragma once #include #include @@ -82,5 +81,3 @@ static void quant_qs4cx_f32(size_t n, size_t k, const float* rhs_f32, uint8_t* r rhs_scales_f32[row_idx] = recip_scale0; } }; - -#endif /* KAI_BENCHMARK_MATMUL_MATMUL_UTILS_HPP */ -- GitLab From be1eb00ddbd5077e11ca40d54b0a527c9dcc6a61 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Mon, 7 Oct 2024 15:56:22 +0200 Subject: [PATCH 09/10] Correct the size of array allocation. 
Signed-off-by: Jens Elofsson --- benchmark/matmul/matmul_f32_f32p_f32p.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmark/matmul/matmul_f32_f32p_f32p.cpp b/benchmark/matmul/matmul_f32_f32p_f32p.cpp index b2795ca0..64f605b8 100644 --- a/benchmark/matmul/matmul_f32_f32p_f32p.cpp +++ b/benchmark/matmul/matmul_f32_f32p_f32p.cpp @@ -68,11 +68,11 @@ struct kai_matmul_f32_f32p_f32p_sme { const size_t sr = variant.ukernel.get_sr(); const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(m, k, mr, kr, sr); - float* lhs_packed = new float[lhs_packed_size]; + float* lhs_packed = new float[lhs_packed_size / sizeof(float)]; const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(n, k); - float* rhs_packed = new float[rhs_packed_size]; + float* rhs_packed = new float[rhs_packed_size / sizeof(float)]; const size_t lhs_stride = k * sizeof(float); const size_t rhs_stride = n * sizeof(float); -- GitLab From eaabb1d529c1f061f21e598e8e4677a440813554 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Tue, 8 Oct 2024 15:10:47 +0200 Subject: [PATCH 10/10] Update README.md Signed-off-by: Jens Elofsson --- benchmark/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmark/README.md b/benchmark/README.md index 10429d66..f409294e 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -26,7 +26,6 @@ $ cmake -DCMAKE_TOOLCHAIN_FILE=/path/to/android-ndk/build/cmake/android.toolchai ## Usage -For now, the only kernel that is included in the benchmarking suite is matmul_clamp_f32_qai8dxp_qsi4cxp. The dimensions of the LHS- and RHS-matrices needs to be specified with the `-m`, `-n` and `-k` options. The shape of the LHS-matrix is MxK, and the shape of the RHS-matrix is KxN. -- GitLab