diff --git a/benchmark/README.md b/benchmark/README.md new file mode 100644 index 0000000000000000000000000000000000000000..10429d6627489ef593bea54f3e7768e5b2ebcc3a --- /dev/null +++ b/benchmark/README.md @@ -0,0 +1,65 @@ + + +# KleidiAI benchmark tool + +## Building + +From the kleidiai-root: + +### Linux®-target + +``` +$ mkdir -p build && cd build +$ cmake -DCMAKE_C_COMPILER=/path/to/aarch64-none-linux-gnu-gcc -DCMAKE_CXX_COMPILER=/path/to/aarch64-none-linux-gnu-g++ -DKLEIDIAI_BUILD_BENCHMARK=ON -DCMAKE_BUILD_TYPE=Release ../ +``` + +### Android™-target + +``` +$ mkdir -p build && cd build +$ cmake -DCMAKE_TOOLCHAIN_FILE=/path/to/android-ndk/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=30 -DKLEIDIAI_BUILD_BENCHMARK=ON -DCMAKE_BUILD_TYPE=Release ../ +``` + +## Usage + +For now, the only kernel that is included in the benchmarking suite is matmul_clamp_f32_qai8dxp_qsi4cxp. +The dimensions of the LHS- and RHS-matrices needs to be specified with the `-m`, `-n` and `-k` options. +The shape of the LHS-matrix is MxK, and the shape of the RHS-matrix is KxN. + +``` +$ ./kleidiai_benchmark -m 13 -n 17 -k 18 +Run on (8 X 1800 MHz CPU s) +Load Average: 10.01, 10.06, 10.06 +----------------------------------------------------------------------------------------------------- +Benchmark Time CPU Iterations +----------------------------------------------------------------------------------------------------- +matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod 123 ns 123 ns 1234567 +matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod 123 ns 123 ns 1234567 +matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm 123 ns 123 ns 1234567 +matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm 123 ns 123 ns 1234567 +matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm 123 ns 123 ns 1234567 +matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm 123 ns 123 ns 1234567 +``` + +### Filtering + +Testcases can be filtered using the `--benchmark_filter` accepts a regex. To run only the dotprod-testcases: +(Note: The measurement results are placeholders) + +``` +$ kleidiai_benchmark --benchmark_filter=dotprod -m 13 -n 17 -k 18 +Run on (8 X 1800 MHz CPU s) +Load Average: 10.09, 10.13, 10.09 +----------------------------------------------------------------------------------------------------- +Benchmark Time CPU Iterations +----------------------------------------------------------------------------------------------------- +matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod 123 ns 123 ns 1234567 +matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod 123 ns 123 ns 1234567 +``` + +This application uses [Google Benchmark](https://github.com/google/benchmark), so all options that Google Benchmark provides can be used. +To list the options provided use the `--help` flag or refer to the [user guide](https://github.com/google/benchmark/blob/main/docs/user_guide.md). diff --git a/benchmark/main.cpp b/benchmark/main.cpp index ef88f39b5255ea3791e29d3255e19ad47f4d3477..5d8e7a32884e1bbfb7b515d23350283421bc71e0 100644 --- a/benchmark/main.cpp +++ b/benchmark/main.cpp @@ -3,21 +3,68 @@ // // SPDX-License-Identifier: Apache-2.0 // -#include + +#include #include #include +#include + +#include "benchmark/matmul/matmul_f32.hpp" + +void print_usage(char* name) { + fprintf(stderr, "Usage:\n"); + fprintf(stderr, "%s -m 13 -n 17 -k 18\n", name); + fprintf(stderr, "\n"); + fprintf(stderr, "For additional options:\n"); + fprintf(stderr, "%s --help\n", name); +} + +int main(int argc, char** argv) { + ::benchmark::Initialize(&argc, argv); + + bool mflag = false; + bool nflag = false; + bool kflag = false; -template -void hello_benchmark(benchmark::State& state, Args&&... args) { - volatile size_t a = 0; - for (auto _ : state) { - for (int i = 0; i < 100; i++) { - a++; + size_t m = 0; + size_t n = 0; + size_t k = 0; + + int opt; + while ((opt = getopt(argc, argv, "m:n:k:")) != -1) { + switch (opt) { + case 'm': + m = atoi(optarg); + mflag = true; + break; + case 'n': + n = atoi(optarg); + nflag = true; + break; + case 'k': + k = atoi(optarg); + kflag = true; + break; + case '?': + // Fallthrough + default: + print_usage(argv[0]); + exit(EXIT_FAILURE); } } -} -BENCHMARK(hello_benchmark); + if (!mflag || !nflag || !kflag) { + print_usage(argv[0]); + exit(EXIT_FAILURE); + } + + kai_matmul matmul_f32; + for (int i = 0; i < num_ukernel_variants; i++) { + ::benchmark::RegisterBenchmark(ukernel_variants[i].name, matmul_f32, ukernel_variants[i], m, n, k); + } -BENCHMARK_MAIN(); + ::benchmark::RunSpecifiedBenchmarks(); + ::benchmark::Shutdown(); + return 0; +} diff --git a/benchmark/matmul/matmul_f32.hpp b/benchmark/matmul/matmul_f32.hpp new file mode 100644 index 0000000000000000000000000000000000000000..48f5ed03fdc3781ddf01f196db0acac2d14396ef --- /dev/null +++ b/benchmark/matmul/matmul_f32.hpp @@ -0,0 +1,203 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#ifndef KAI_BENCHMARK_MATMUL_MATMUL_F32_HPP +#define KAI_BENCHMARK_MATMUL_MATMUL_F32_HPP + +#include + +#include + +#include "benchmark/matmul/matmul_utils.hpp" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h" +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h" + +struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { + kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel ukernel; + std::string name = {}; +}; + +kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm"}, +}; + +// Number of micro-kernel variants stored in the array +const size_t num_ukernel_variants = sizeof(ukernel_variants) / sizeof(ukernel_variants[0]); + +struct kai_matmul { + template + void operator()( + benchmark::State& state, kai_matmul_ukernel_f32_qa8dxp_qs4cxp variant, size_t m, size_t n, size_t k) const { + const size_t seed_lhs = 4568; + const size_t seed_rhs = seed_lhs + 4; + + const size_t lhs_native_size_f32 = m * k * sizeof(float); + const size_t rhs_native_size_f32 = n * k * sizeof(float); + const size_t rhs_native_size_qs4cx = n * (k / 2) * sizeof(uint8_t); + const size_t rhs_scales_size_f32 = n * sizeof(float); + + // Allocate the memory + uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32]; + uint8_t* rhs_native_mtx_f32 = new uint8_t[rhs_native_size_f32]; + uint8_t* rhs_native_mtx_qs4cx = new uint8_t[rhs_native_size_qs4cx]; + uint8_t* rhs_scales_f32 = new uint8_t[rhs_scales_size_f32]; + + fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs); + fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs); + + quant_qs4cx_f32(n, k, (const float*)rhs_native_mtx_f32, rhs_native_mtx_qs4cx, (float*)rhs_scales_f32); + + delete[] rhs_native_mtx_f32; + + // Get the packing parameters + const size_t mr = variant.ukernel.get_mr(); + const size_t nr = variant.ukernel.get_nr(); + const size_t kr = variant.ukernel.get_kr(); + const size_t sr = variant.ukernel.get_sr(); + + // Get the size in bytes for the packed matrices + const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr); + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(n, k, nr, kr, sr); + const size_t dst_size = variant.ukernel.get_dst_size(m, n); + + // Allocate the matrices + uint8_t* lhs_packed_mtx_qa8dx = new uint8_t[lhs_packed_size]; + uint8_t* rhs_packed_mtx_qs4cx = new uint8_t[rhs_packed_size]; + uint8_t* dst_act_mtx_f32 = new uint8_t[dst_size]; + + // If the RHS matrix contains constant values, the packing can be performed + // only once + struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + + // RHS packing + kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + 1, n, k, nr, kr, sr, // Packing arguments + (const uint8_t*)(rhs_native_mtx_qs4cx), // RHS + NULL, // Bias + (const float*)(rhs_scales_f32), // Scale + rhs_packed_mtx_qs4cx, // RHS packed + 0, ¶ms); + + // LHS packing + kai_run_lhs_quant_pack_qai8dxp_f32( + m, k, mr, kr, sr, 0, // Packing arguments + (const float*)lhs_native_mtx_f32, // LHS + k * sizeof(float), // LHS stride + lhs_packed_mtx_qa8dx); // LHS packed + + const size_t dst_stride = n * sizeof(float); + const size_t lhs_offset = variant.ukernel.get_lhs_packed_offset(0, k); + const size_t rhs_offset = variant.ukernel.get_rhs_packed_offset(0, k); + const size_t dst_offset = variant.ukernel.get_dst_offset(0, 0, dst_stride); + + const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx + lhs_offset); + const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4cx + rhs_offset); + float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + dst_offset); + + // Matmul + for (auto _ : state) { + variant.ukernel.run_matmul( + m, n, k, // Dimensions + lhs_ptr, // LHS packed + rhs_ptr, // RHS packed + dst_ptr, // DST + dst_stride, // DST stride (row) + sizeof(float), // DST stride (col) + -FLT_MAX, FLT_MAX // Min and max for the clamp operation + ); + } + + delete[] lhs_packed_mtx_qa8dx; + delete[] rhs_packed_mtx_qs4cx; + delete[] dst_act_mtx_f32; + + delete[] lhs_native_mtx_f32; + delete[] rhs_native_mtx_qs4cx; + delete[] rhs_scales_f32; + } +}; /* struct kai_matmul */ + +#endif /* KAI_BENCHMARK_MATMUL_MATMUL_F32_HPP */ diff --git a/benchmark/matmul/matmul_utils.hpp b/benchmark/matmul/matmul_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d2c7e0a3331ed48ea54577cd83c7b552046d76f1 --- /dev/null +++ b/benchmark/matmul/matmul_utils.hpp @@ -0,0 +1,86 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#ifndef KAI_BENCHMARK_MATMUL_MATMUL_UTILS_HPP +#define KAI_BENCHMARK_MATMUL_MATMUL_UTILS_HPP + +#include +#include +#include +#include +#include +#include + +#define INT4_MIN (-8) +#define INT4_MAX (7) + +static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, size_t seed) { + std::srand(seed); + + // Fill the array with random values between -1 and 1 + for (int i = 0; i < num_rows * num_cols; i++) { + dst[i] = (float)((double)std::rand() / RAND_MAX) * 2 - 1; + } +} + +static void quant_qs4cx_f32(size_t n, size_t k, const float* rhs_f32, uint8_t* rhs_qs4cx, float* rhs_scales_f32) { + const size_t dst_stride = (k / 2) * sizeof(int8_t); + + for (size_t row_idx = 0; row_idx < n; ++row_idx) { + const float* src_ptr = rhs_f32 + row_idx * k; + + float max0 = -FLT_MAX; + float min0 = FLT_MAX; + + // Find min/max for each channel + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + const float src0_0 = src_ptr[k_idx]; + + max0 = std::max(src0_0, max0); + min0 = std::min(src0_0, min0); + } + + // Maximum/minimum int8 values + const float qmin = (float)INT4_MIN; + const float qmax = (float)INT4_MAX; + + const float rmin0 = std::min(0.0f, min0); + const float rmax0 = std::max(0.0f, max0); + + const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0); + + // Reciprocal to quantize + const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f; + + uint8_t* dst_ptr = rhs_qs4cx + row_idx * dst_stride; + + // Quantize the channels + for (size_t k_idx = 0; k_idx < k; k_idx += 2) { + const float src0_0 = src_ptr[k_idx + 0]; + const float src0_1 = src_ptr[k_idx + 1]; + + // Scale the values + int32_t v0_s32 = (int32_t)(round(src0_0 * scale0)); + int32_t v1_s32 = (int32_t)(round(src0_1 * scale0)); + + // Maximum/minimum int4 values + v0_s32 = std::clamp(v0_s32, INT4_MIN, INT4_MAX); + v1_s32 = std::clamp(v1_s32, INT4_MIN, INT4_MAX); + + int32_t v0_u8 = (uint8_t)(v0_s32 + 8); + int32_t v1_u8 = (uint8_t)(v1_s32 + 8); + + const uint8_t rhs_v0 = (v1_u8 << 4) | v0_u8; + + dst_ptr[0] = rhs_v0; + dst_ptr += sizeof(uint8_t); + } + + rhs_scales_f32[row_idx] = recip_scale0; + } +}; + +#endif /* KAI_BENCHMARK_MATMUL_MATMUL_UTILS_HPP */