diff --git a/CMakeLists.txt b/CMakeLists.txt index 5ed083105bd7b0904bc81984f1c12d5fb9828ad9..d2f4b5feced58d4e0b74ae9232d9cb71aae588e9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -446,6 +446,8 @@ if(KLEIDIAI_BUILD_BENCHMARK) benchmark/main.cpp benchmark/matmul/matmul_registry.cpp ) + set_source_files_properties(benchmark/matmul/matmul_registry.cpp + PROPERTIES COMPILE_OPTIONS "-march=armv8-a+sve2${KLEIDIAI_INTERNAL_EXTRA_ARCH}") target_link_libraries(kleidiai_benchmark PRIVATE diff --git a/benchmark/matmul/matmul_benchmark_logic.hpp b/benchmark/matmul/matmul_benchmark_logic.hpp index cff73a2d26a2852f332d3cf08c21054357addb59..e3947f464dfe90f8cad604b491b9aca1dc8a7405 100644 --- a/benchmark/matmul/matmul_benchmark_logic.hpp +++ b/benchmark/matmul/matmul_benchmark_logic.hpp @@ -12,6 +12,7 @@ #include #include +#include "kai/kai_common.h" #include "matmul_interface.hpp" #include "matmul_runner.hpp" @@ -69,9 +70,15 @@ void kai_benchmark_matmul( } // Create sufficiently large buffers - const size_t lhs_size = m * k * sizeof(uint64_t); - const size_t rhs_size = n * k * sizeof(uint64_t); - const size_t dst_size = m * n * sizeof(uint32_t); + size_t lhs_size = m * k * sizeof(uint64_t); + size_t rhs_size = n * k * sizeof(uint64_t); + size_t dst_size = m * n * sizeof(uint32_t); + + if (test::cpu_has_sme() || test::cpu_has_sme2()) { + lhs_size *= kai_get_sme_vector_length_u32(); + rhs_size *= kai_get_sme_vector_length_u32(); + dst_size *= kai_get_sme_vector_length_u32(); + } const Buffer lhs(lhs_size); const Buffer rhs(rhs_size); diff --git a/benchmark/matmul/matmul_runner.hpp b/benchmark/matmul/matmul_runner.hpp index 0f04cd65b29d17102ec06508756a8090f23f05ca..85c0496991742dcc4ba56c207ee9d729d7673b3e 100644 --- a/benchmark/matmul/matmul_runner.hpp +++ b/benchmark/matmul/matmul_runner.hpp @@ -46,7 +46,7 @@ public: n_ = n; k_ = k; - lhs_stride_ = k_ * data_type_size_in_bits(dst_type_); + lhs_stride_ = k_ * data_type_size_in_bits(dst_type_) / 8; dst_stride_row_ = n_ * data_type_size_in_bits(dst_type_) / 8; dst_stride_col_ = data_type_size_in_bits(dst_type_) / 8; }