diff --git a/CMakeLists.txt b/CMakeLists.txt index 7fc0dc2c93edefb8d8e3f4bec2a43951299cfaa1..3b6ea9a5a53bb322fc6f5f0e776ea864f9181abc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -225,8 +225,13 @@ if(KLEIDIAI_BUILD_BENCHMARK) include(FetchGBench) add_executable(kleidiai_benchmark - benchmark/main.cpp) + benchmark/main.cpp + benchmark/matmul/matmul_f32_f32p_f32p.cpp + benchmark/matmul/matmul_f32.cpp + test/common/cpu_info.cpp) + set_source_files_properties(benchmark/matmul/matmul_f32_f32p_f32p.cpp PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+sve+sve2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(benchmark/matmul/matmul_f32.cpp PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+dotprod+i8mm${KLEIDIAI_INTERNAL_EXTRA_ARCH}) target_link_libraries( kleidiai_benchmark kleidiai diff --git a/benchmark/README.md b/benchmark/README.md index 10429d6627489ef593bea54f3e7768e5b2ebcc3a..f409294ea1b825d279af301b0fe5bcd11a34436b 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -26,7 +26,6 @@ $ cmake -DCMAKE_TOOLCHAIN_FILE=/path/to/android-ndk/build/cmake/android.toolchai ## Usage -For now, the only kernel that is included in the benchmarking suite is matmul_clamp_f32_qai8dxp_qsi4cxp. The dimensions of the LHS- and RHS-matrices needs to be specified with the `-m`, `-n` and `-k` options. The shape of the LHS-matrix is MxK, and the shape of the RHS-matrix is KxN. 
diff --git a/benchmark/main.cpp b/benchmark/main.cpp index 5d8e7a32884e1bbfb7b515d23350283421bc71e0..350218045e3f224ae6d84887a702e1aba0a45ef0 100644 --- a/benchmark/main.cpp +++ b/benchmark/main.cpp @@ -4,6 +4,7 @@ // SPDX-License-Identifier: Apache-2.0 // +#include #include #include @@ -11,6 +12,7 @@ #include #include "benchmark/matmul/matmul_f32.hpp" +#include "benchmark/matmul/matmul_f32_f32p_f32p.hpp" void print_usage(char* name) { fprintf(stderr, "Usage:\n"); @@ -59,10 +61,9 @@ int main(int argc, char** argv) { exit(EXIT_FAILURE); } - kai_matmul matmul_f32; - for (int i = 0; i < num_ukernel_variants; i++) { - ::benchmark::RegisterBenchmark(ukernel_variants[i].name, matmul_f32, ukernel_variants[i], m, n, k); - } + kai::bench::matmul_f32_qa8dxp_qs4cxp::dotprod::RegisterBenchmarks(m, n, k); + kai::bench::matmul_f32_qa8dxp_qs4cxp::i8mm::RegisterBenchmarks(m, n, k); + kai::bench::matmul_f32_f32p_f32p::RegisterBenchmarks(m, n, k); ::benchmark::RunSpecifiedBenchmarks(); ::benchmark::Shutdown(); diff --git a/benchmark/matmul/matmul_f32.cpp b/benchmark/matmul/matmul_f32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e956a7c58c4b7e277c5f59be1b7fd26c30e68dd8 --- /dev/null +++ b/benchmark/matmul/matmul_f32.cpp @@ -0,0 +1,234 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__ARM_FEATURE_DOTPROD) && !defined(__ARM_FEATURE_MATMUL_INT8) +#error "Dotprod and I8mm extensions required to compile this example" +#else + +#include + +#include +#include + +#include "benchmark/matmul/matmul_utils.hpp" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h" +#include 
"kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h" +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" +#include "test/common/cpu_info.hpp" + +namespace kai::bench::matmul_f32_qa8dxp_qs4cxp { + +const size_t seed_lhs = 4568; +const size_t seed_rhs = seed_lhs + 4; +const size_t seed_bias = seed_rhs + 4; + +struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { + kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel ukernel; + std::string name = {}; +}; + +struct kai_matmul_f32_qai8_qsi4 { + template + void operator()( + benchmark::State& state, kai_matmul_ukernel_f32_qa8dxp_qs4cxp variant, size_t m, size_t n, size_t k) const { + const size_t lhs_native_size_f32 = m * k * sizeof(float); + const size_t rhs_native_size_f32 = n * k * sizeof(float); + const size_t rhs_native_size_qs4cx = n * (k / 2) * sizeof(uint8_t); + const size_t rhs_scales_size_f32 = n * sizeof(float); + + // Allocate the memory + uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32]; + uint8_t* rhs_native_mtx_f32 = new uint8_t[rhs_native_size_f32]; + uint8_t* rhs_native_mtx_qs4cx = new uint8_t[rhs_native_size_qs4cx]; + uint8_t* rhs_scales_f32 = new uint8_t[rhs_scales_size_f32]; + + fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs); + fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs); + + quant_qs4cx_f32(n, k, (const 
float*)rhs_native_mtx_f32, rhs_native_mtx_qs4cx, (float*)rhs_scales_f32); + + delete[] rhs_native_mtx_f32; + + // Get the packing parameters + const size_t mr = variant.ukernel.get_mr(); + const size_t nr = variant.ukernel.get_nr(); + const size_t kr = variant.ukernel.get_kr(); + const size_t sr = variant.ukernel.get_sr(); + + // Get the size in bytes for the packed matrices + const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr); + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(n, k, nr, kr, sr); + const size_t dst_size = variant.ukernel.get_dst_size(m, n); + + // Allocate the matrices + uint8_t* lhs_packed_mtx_qa8dx = new uint8_t[lhs_packed_size]; + uint8_t* rhs_packed_mtx_qs4cx = new uint8_t[rhs_packed_size]; + uint8_t* dst_act_mtx_f32 = new uint8_t[dst_size]; + + // If the RHS matrix contains constant values, the packing can be performed + // only once + struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + + // RHS packing + kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( + 1, n, k, nr, kr, sr, // Packing arguments + (const uint8_t*)(rhs_native_mtx_qs4cx), // RHS + NULL, // Bias + (const float*)(rhs_scales_f32), // Scale + rhs_packed_mtx_qs4cx, // RHS packed + 0, ¶ms); + + // LHS packing + kai_run_lhs_quant_pack_qai8dxp_f32( + m, k, mr, kr, sr, 0, // Packing arguments + (const float*)lhs_native_mtx_f32, // LHS + k * sizeof(float), // LHS stride + lhs_packed_mtx_qa8dx); // LHS packed + + const size_t dst_stride = n * sizeof(float); + const size_t lhs_offset = variant.ukernel.get_lhs_packed_offset(0, k); + const size_t rhs_offset = variant.ukernel.get_rhs_packed_offset(0, k); + const size_t dst_offset = variant.ukernel.get_dst_offset(0, 0, dst_stride); + + const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx + lhs_offset); + const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4cx + rhs_offset); 
+ float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + dst_offset); + + // Matmul + for (auto _ : state) { + variant.ukernel.run_matmul( + m, n, k, // Dimensions + lhs_ptr, // LHS packed + rhs_ptr, // RHS packed + dst_ptr, // DST + dst_stride, // DST stride (row) + sizeof(float), // DST stride (col) + -FLT_MAX, FLT_MAX // Min and max for the clamp operation + ); + } + + delete[] lhs_packed_mtx_qa8dx; + delete[] rhs_packed_mtx_qs4cx; + delete[] dst_act_mtx_f32; + + delete[] lhs_native_mtx_f32; + delete[] rhs_native_mtx_qs4cx; + delete[] rhs_scales_f32; + } +}; /* kai_matmul_f32_qai8_qsi4 */ + +namespace dotprod { +kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + 
kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod"}, +}; + +void RegisterBenchmarks(size_t m, size_t n, size_t k) { + kai::bench::matmul_f32_qa8dxp_qs4cxp::kai_matmul_f32_qai8_qsi4 matmul_f32; + if (kai::test::cpu_has_dotprod()) { + for (const auto& variant : dotprod::ukernel_variants) { + ::benchmark::RegisterBenchmark(variant.name, matmul_f32, variant, m, n, k); + } + } +} + +} /* namespace dotprod */ + +namespace i8mm { +kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + 
kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + 
kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm"}, +}; + +void RegisterBenchmarks(size_t m, size_t n, size_t k) { + kai::bench::matmul_f32_qa8dxp_qs4cxp::kai_matmul_f32_qai8_qsi4 matmul_f32; + if (kai::test::cpu_has_i8mm()) { + for (const auto& variant : i8mm::ukernel_variants) { + ::benchmark::RegisterBenchmark(variant.name, matmul_f32, variant, m, n, k); + } + } +} + +}; /* namespace i8mm */ +}; // namespace kai::bench::matmul_f32_qa8dxp_qs4cxp + +#endif /* !defined(__ARM_FEATURE_DOTPROD) && !defined(__ARM_FEATURE_MATMUL_INT8) */ diff --git a/benchmark/matmul/matmul_f32.hpp b/benchmark/matmul/matmul_f32.hpp index e8aa2dfb806dc5b811cca2669d1a54df69dfeeb2..1245ca8a2c150498afcca0aca063fbaca8ee0c8e 100644 --- a/benchmark/matmul/matmul_f32.hpp +++ b/benchmark/matmul/matmul_f32.hpp @@ -4,200 +4,18 @@ // SPDX-License-Identifier: Apache-2.0 // -#ifndef KAI_BENCHMARK_MATMUL_MATMUL_F32_HPP -#define KAI_BENCHMARK_MATMUL_MATMUL_F32_HPP +#pragma once -#include +#include -#include +namespace kai::bench::matmul_f32_qa8dxp_qs4cxp { -#include "benchmark/matmul/matmul_utils.hpp" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h" -#include 
"kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h" -#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" -#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" +namespace dotprod { +void RegisterBenchmarks(size_t m, size_t n, size_t k); +}; /* namespace dotprod */ -struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { - kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel ukernel; - std::string name = {}; -}; +namespace i8mm { +void RegisterBenchmarks(size_t m, size_t n, size_t k); +}; /* namespace i8mm */ -kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { - {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - 
kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, - "matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - "matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm"}, - 
{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - 
kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm"}, -}; - -// Number of micro-kernel variants stored in the array -const size_t num_ukernel_variants = sizeof(ukernel_variants) / sizeof(ukernel_variants[0]); - -struct kai_matmul { - template - void operator()( - benchmark::State& state, kai_matmul_ukernel_f32_qa8dxp_qs4cxp variant, size_t m, size_t n, size_t k) const { - const size_t seed_lhs = 4568; - const size_t seed_rhs = seed_lhs + 4; - - const size_t lhs_native_size_f32 = m * k * sizeof(float); - const size_t rhs_native_size_f32 = n * k * sizeof(float); - const size_t rhs_native_size_qs4cx = n * (k / 2) * sizeof(uint8_t); - const size_t rhs_scales_size_f32 = n * sizeof(float); - - // Allocate the memory - uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32]; - uint8_t* rhs_native_mtx_f32 = new uint8_t[rhs_native_size_f32]; - uint8_t* rhs_native_mtx_qs4cx = new uint8_t[rhs_native_size_qs4cx]; - uint8_t* rhs_scales_f32 = new uint8_t[rhs_scales_size_f32]; - - fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs); - fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs); - - quant_qs4cx_f32(n, k, (const float*)rhs_native_mtx_f32, rhs_native_mtx_qs4cx, (float*)rhs_scales_f32); - - delete[] rhs_native_mtx_f32; - - // Get the packing parameters - const size_t mr = variant.ukernel.get_mr(); - const size_t nr = variant.ukernel.get_nr(); - 
const size_t kr = variant.ukernel.get_kr(); - const size_t sr = variant.ukernel.get_sr(); - - // Get the size in bytes for the packed matrices - const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr); - const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(n, k, nr, kr, sr); - const size_t dst_size = variant.ukernel.get_dst_size(m, n); - - // Allocate the matrices - uint8_t* lhs_packed_mtx_qa8dx = new uint8_t[lhs_packed_size]; - uint8_t* rhs_packed_mtx_qs4cx = new uint8_t[rhs_packed_size]; - uint8_t* dst_act_mtx_f32 = new uint8_t[dst_size]; - - // If the RHS matrix contains constant values, the packing can be performed - // only once - struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params; - params.lhs_zero_point = 1; - params.rhs_zero_point = 8; - - // RHS packing - kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( - 1, n, k, nr, kr, sr, // Packing arguments - (const uint8_t*)(rhs_native_mtx_qs4cx), // RHS - NULL, // Bias - (const float*)(rhs_scales_f32), // Scale - rhs_packed_mtx_qs4cx, // RHS packed - 0, ¶ms); - - // LHS packing - kai_run_lhs_quant_pack_qai8dxp_f32( - m, k, mr, kr, sr, 0, // Packing arguments - (const float*)lhs_native_mtx_f32, // LHS - k * sizeof(float), // LHS stride - lhs_packed_mtx_qa8dx); // LHS packed - - const size_t dst_stride = n * sizeof(float); - const size_t lhs_offset = variant.ukernel.get_lhs_packed_offset(0, k); - const size_t rhs_offset = variant.ukernel.get_rhs_packed_offset(0, k); - const size_t dst_offset = variant.ukernel.get_dst_offset(0, 0, dst_stride); - - const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx + lhs_offset); - const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4cx + rhs_offset); - float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + dst_offset); - - // Matmul - for (auto _ : state) { - variant.ukernel.run_matmul( - m, n, k, // Dimensions - lhs_ptr, // LHS packed - rhs_ptr, // RHS packed - dst_ptr, // DST - 
dst_stride, // DST stride (row) - sizeof(float), // DST stride (col) - -FLT_MAX, FLT_MAX // Min and max for the clamp operation - ); - } - - delete[] lhs_packed_mtx_qa8dx; - delete[] rhs_packed_mtx_qs4cx; - delete[] dst_act_mtx_f32; - - delete[] lhs_native_mtx_f32; - delete[] rhs_native_mtx_qs4cx; - delete[] rhs_scales_f32; - } -}; /* struct kai_matmul */ - -#endif /* KAI_BENCHMARK_MATMUL_MATMUL_F32_HPP */ +}; // namespace kai::bench::matmul_f32_qa8dxp_qs4cxp diff --git a/benchmark/matmul/matmul_f32_f32p_f32p.cpp b/benchmark/matmul/matmul_f32_f32p_f32p.cpp new file mode 100644 index 0000000000000000000000000000000000000000..64f605b8c7cf99f7d1dd84ddc3f5e1b76bf1938d --- /dev/null +++ b/benchmark/matmul/matmul_f32_f32p_f32p.cpp @@ -0,0 +1,120 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2 +#else // Architectural feature check + +#include "benchmark/matmul/matmul_f32_f32p_f32p.hpp" + +#include + +#include "benchmark/matmul/matmul_utils.hpp" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32p_interface.h" +#include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" +#include "test/common/cpu_info.hpp" + +namespace kai::bench::matmul_f32_f32p_f32p { + +const size_t seed_lhs = 4568; +const size_t seed_rhs = seed_lhs + 4; +const size_t seed_bias = seed_rhs + 4; + +struct kai_matmul_ukernel_f32_f32p_f32p { + kai_matmul_clamp_f32_f32p_f32p_ukernel ukernel; + std::string name = {}; +}; + +kai_matmul_ukernel_f32_f32p_f32p sme_variants[] = { + {kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + 
kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + "matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa"}, +}; + +struct kai_matmul_f32_f32p_f32p_sme { + template + void operator()( + benchmark::State& state, kai_matmul_ukernel_f32_f32p_f32p variant, size_t m, size_t n, size_t k) const { + const size_t lhs_size = m * k; + const size_t rhs_size = n * k; + const size_t bias_size = n; + const size_t dst_size = m * n; + + float* lhs = new float[lhs_size]; + float* rhs = new float[rhs_size]; + float* bias = new float[bias_size]; + + fill_uniform_random(m, k, lhs, seed_lhs); + fill_uniform_random(k, n, rhs, seed_rhs); + fill_uniform_random(1, n, bias, seed_bias); + + const size_t mr = variant.ukernel.get_mr(); + const size_t nr = variant.ukernel.get_nr(); + const size_t kr = variant.ukernel.get_kr(); + const size_t sr = variant.ukernel.get_sr(); + + const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(m, k, mr, kr, sr); + float* lhs_packed = new float[lhs_packed_size / sizeof(float)]; + + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(n, k); + + float* rhs_packed = new float[rhs_packed_size / sizeof(float)]; + + const size_t lhs_stride = k * sizeof(float); + const size_t rhs_stride = n * sizeof(float); + const size_t 
dst_stride_row = n * sizeof(float); + const size_t dst_stride_col = sizeof(float); + kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( + 1, n, k, nr, kr, sr, // Packing arguments + rhs_stride, rhs, bias, NULL, rhs_packed, 0, NULL); + + kai_run_lhs_pack_f32p2vlx1_f32_sme(m, k, mr, kr, sr, 0, lhs, k * sizeof(float), lhs_packed); + + float* dst = new float[dst_size]; + for (auto _ : state) { + // run matmul + variant.ukernel.run_matmul( + m, n, k, // Dimensions + lhs_packed, // LHS + rhs_packed, // RHS packed + dst, // DST + dst_stride_row, // DST stride (row) + dst_stride_col, // DST stride (col) + -FLT_MAX, FLT_MAX // Min and max for the clamp operation + ); + } + + delete[] lhs; + delete[] rhs; + delete[] bias; + delete[] rhs_packed; + delete[] lhs_packed; + delete[] dst; + } +}; /* struct kai_matmul_f32_f32p_f32p_sme */ + +void RegisterBenchmarks(size_t m, size_t n, size_t k) { + kai_matmul_f32_f32p_f32p_sme sme_kernel; + if (kai::test::cpu_has_sme2()) { + for (const auto& variant : sme_variants) { + ::benchmark::RegisterBenchmark(variant.name, sme_kernel, variant, m, n, k)->Iterations(2000); + } + } +} +} // namespace kai::bench::matmul_f32_f32p_f32p + +#endif /* defined(__aarch64__) && defined(__ARM_FEATURE_SVE2) */ diff --git a/benchmark/matmul/matmul_f32_f32p_f32p.hpp b/benchmark/matmul/matmul_f32_f32p_f32p.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f05597ce4b397604df7fcb0b2678583d85e2aa5b --- /dev/null +++ b/benchmark/matmul/matmul_f32_f32p_f32p.hpp @@ -0,0 +1,15 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace kai::bench::matmul_f32_f32p_f32p { + +void RegisterBenchmarks(size_t m, size_t n, size_t k); + +}; diff --git a/benchmark/matmul/matmul_utils.hpp b/benchmark/matmul/matmul_utils.hpp index d2c7e0a3331ed48ea54577cd83c7b552046d76f1..a19fdbd1d0512268d8da0de50345c0bd1eb62b16 100644 --- 
a/benchmark/matmul/matmul_utils.hpp +++ b/benchmark/matmul/matmul_utils.hpp @@ -4,8 +4,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#ifndef KAI_BENCHMARK_MATMUL_MATMUL_UTILS_HPP -#define KAI_BENCHMARK_MATMUL_MATMUL_UTILS_HPP +#pragma once #include #include @@ -82,5 +81,3 @@ static void quant_qs4cx_f32(size_t n, size_t k, const float* rhs_f32, uint8_t* r rhs_scales_f32[row_idx] = recip_scale0; } }; - -#endif /* KAI_BENCHMARK_MATMUL_MATMUL_UTILS_HPP */ diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32p_interface.h b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32p_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..3e40c2f401c60510ccbce41afa4448216988d2d7 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32p_interface.h @@ -0,0 +1,51 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// All micro-kernels variants of the same type share the same interfaces +// In this case, the micro-kernel type is: matmul_clamp_f32_f32p_f32p + +/// Micro-kernel helper functions ("get" methods) +typedef size_t (*kai_matmul_clamp_f32_f32p_f32p_get_m_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_f32p_f32p_get_n_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_f32p_f32p_get_mr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_f32p_f32p_get_nr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_f32p_f32p_get_kr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_f32p_f32p_get_sr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_f32p_f32p_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k); +typedef size_t (*kai_matmul_clamp_f32_f32p_f32p_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); +typedef size_t 
(*kai_matmul_clamp_f32_f32p_f32p_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride); +typedef size_t (*kai_matmul_clamp_f32_f32p_f32p_get_dst_size_func_t)(size_t m, size_t n); + +/// Micro-kernel core function ("run" method) +typedef void (*kai_matmul_clamp_f32_f32p_f32p_run_matmul_func_t)( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, + size_t dst_stride_col, float clamp_min, float clamp_max); + +/// Micro-kernel interface +struct kai_matmul_clamp_f32_f32p_f32p_ukernel { + kai_matmul_clamp_f32_f32p_f32p_get_m_step_func_t get_m_step; + kai_matmul_clamp_f32_f32p_f32p_get_n_step_func_t get_n_step; + kai_matmul_clamp_f32_f32p_f32p_get_mr_func_t get_mr; + kai_matmul_clamp_f32_f32p_f32p_get_nr_func_t get_nr; + kai_matmul_clamp_f32_f32p_f32p_get_kr_func_t get_kr; + kai_matmul_clamp_f32_f32p_f32p_get_sr_func_t get_sr; + kai_matmul_clamp_f32_f32p_f32p_get_lhs_packed_offset_func_t get_lhs_packed_offset; + kai_matmul_clamp_f32_f32p_f32p_get_rhs_packed_offset_func_t get_rhs_packed_offset; + kai_matmul_clamp_f32_f32p_f32p_get_dst_offset_func_t get_dst_offset; + kai_matmul_clamp_f32_f32p_f32p_get_dst_size_func_t get_dst_size; + kai_matmul_clamp_f32_f32p_f32p_run_matmul_func_t run_matmul; +}; + +#ifdef __cplusplus +} +#endif