diff --git a/.clang-tidy b/.clang-tidy index 4c0dfbe83ae753610ff4e501aa8f3c2e5098886c..f05ad6327cc11a99d41ea6a33ff884fdb4011c87 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -27,6 +27,8 @@ readability-*, -readability-simplify-boolean-expr, -bugprone-easily-swappable-parameters, -cppcoreguidelines-pro-bounds-pointer-arithmetic, --performance-enum-size +-performance-enum-size, +-llvm-else-after-return, +-readability-else-after-return, ' ... diff --git a/CMakeLists.txt b/CMakeLists.txt index b7e6786d35e204fc7190da8603314cb7a182c627..ca7a72af2afda88c9576c117bd5099bb7c241b8b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,11 +63,16 @@ if(KLEIDIAI_BUILD_TESTS) test/common/compare.cpp test/common/matrix_portion.cpp test/common/rect.cpp + test/reference/binary_elementwise.cpp + test/reference/matmul.cpp test/reference/fill.cpp + test/reference/pack.cpp test/reference/quantize.cpp test/reference/reduce.cpp test/reference/round.cpp + + test/tests/matmul_test.cpp ) target_include_directories(kleidiai_test diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6b2ab0410b2d6f72f2d2688c8738009b4bf505dd --- /dev/null +++ b/test/reference/matmul.cpp @@ -0,0 +1,154 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/matmul.hpp" + +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/common/int4.hpp" +#include "test/common/memory.hpp" +#include "test/common/printer.hpp" +#include "test/reference/binary_elementwise.hpp" +#include "test/reference/pack.hpp" +#include "test/reference/quantize.hpp" +#include "test/reference/reduce.hpp" + +namespace kai::test { + +namespace { + +/// Matrix multiplication. +/// +/// @tparam T Data type. +/// +/// @param[in] lhs LHS operand data buffer. 
+/// @param[in] rhs RHS operand data buffer. +/// @param[in] m Output height. +/// @param[in] n Output width. +/// @param[in] k Non-transposed LHS width and non-transposed RHS height. +/// @param[in] lhs_transposed `true` if LHS operand is transposed. +/// @param[in] rhs_transposed `true` if RHS operand is transposed. +/// +/// @return The result data buffer. +template +std::vector matmul_any_type( + const void* lhs, const void* rhs, // + size_t m, size_t n, size_t k, // + bool lhs_transposed, bool rhs_transposed) { + const auto lhs_m_stride = lhs_transposed ? 1 : k; + const auto lhs_k_stride = lhs_transposed ? m : 1; + + const auto rhs_n_stride = rhs_transposed ? k : 1; + const auto rhs_k_stride = rhs_transposed ? 1 : n; + + std::vector dst; + dst.resize(m * n * size_in_bits / 8); + KAI_ASSUME(n * size_in_bits % 8 == 0); + + for (size_t im = 0; im < m; ++im) { + for (size_t in = 0; in < n; ++in) { + T acc{0}; + + for (size_t ik = 0; ik < k; ++ik) { + const auto lhs_value = read_array(lhs, im * lhs_m_stride + ik * lhs_k_stride); + const auto rhs_value = read_array(rhs, in * rhs_n_stride + ik * rhs_k_stride); + acc += lhs_value * rhs_value; + } + + write_array(dst.data(), im * n + in, acc); + } + } + + return dst; +} + +} // namespace + +std::vector matmul_pack_rhs( + const void* data, const void* scales, const void* zero_points, const DataFormat& src_format, + const DataFormat& dst_format, size_t height, size_t width) { + const auto src_dt = src_format.data_type(); + const auto src_qf = src_format.quantization_format(); + + const auto dst_dt = dst_format.data_type(); + const auto dst_qf = dst_format.quantization_format(); + + std::vector tmp_data; + std::vector tmp_scales; + std::vector tmp_zero_points; + + if (src_dt == DataType::QSU4 && src_qf == DataFormat::QuantizationFormat::NONE && // + dst_dt == DataType::QSI4 && dst_qf == DataFormat::QuantizationFormat::PER_ROW) { + // For this specific RHS format conversion: + // + // * 4-bit data is added by 8. 
+ // * Scale is divided by 16. + // * Zero point is accumulation of all values in the same row. + + KAI_ASSUME(zero_points == nullptr); + const int32_t zero_point = 8; + const uint8_t zero_point_i4 = UInt4::pack_u8(UInt4(zero_point), UInt4(zero_point)); + const int32_t row_zero_point = zero_point * static_cast(width); + + KAI_ASSUME(dst_format.subblock_width() > 0); + const auto subblock_width_i32 = static_cast(dst_format.subblock_width()); + const auto subblock_width_f = static_cast(dst_format.subblock_width()); + + tmp_zero_points = reduce_add(data, src_format, height, width, DataFormat(DataType::I32), 0); + tmp_zero_points = sub(tmp_zero_points.data(), DataType::I32, height, 1, &row_zero_point, DataType::I32, 1, 1); + tmp_zero_points = + mul(tmp_zero_points.data(), DataType::I32, height, 1, &subblock_width_i32, DataType::I32, 1, 1); + zero_points = tmp_zero_points.data(); + + tmp_data = add(data, DataType::QSU4, height, width, &zero_point_i4, DataType::QSU4, 1, 1); + data = tmp_data.data(); + + tmp_scales = div(scales, DataType::FP32, height, 1, &subblock_width_f, DataType::FP32, 1, 1); + scales = tmp_scales.data(); + } + + return pack(dst_format, data, scales, zero_points, src_format, height, width); +} + +std::vector matmul( + const void* lhs, const void* lhs_scales, const void* lhs_zero_points, DataType lhs_dt, // + const void* rhs, const void* rhs_scales, const void* rhs_zero_points, DataType rhs_dt, // + DataType dst_dt, // + size_t m, size_t n, size_t k, // + bool lhs_transposed, bool rhs_transposed) { + const auto lhs_h = lhs_transposed ? k : m; + const auto lhs_w = lhs_transposed ? m : k; + + const auto rhs_h = rhs_transposed ? n : k; + const auto rhs_w = rhs_transposed ? 
k : n; + + std::vector tmp_lhs; + std::vector tmp_rhs; + + if (data_type_is_quantized(lhs_dt)) { + tmp_lhs = dequantize( + lhs, lhs_scales, lhs_zero_points, lhs_dt, DataType::FP32, QuantizationMethod::PER_MATRIX, lhs_h, lhs_w); + lhs = tmp_lhs.data(); + } + + if (data_type_is_quantized(rhs_dt)) { + tmp_rhs = dequantize( + rhs, rhs_scales, rhs_zero_points, rhs_dt, DataType::FP32, QuantizationMethod::PER_ROW, rhs_h, rhs_w); + rhs = tmp_rhs.data(); + } + + KAI_ASSUME(dst_dt == DataType::FP32); + const auto tmp_dst = matmul_any_type(lhs, rhs, m, n, k, lhs_transposed, rhs_transposed); + + return tmp_dst; +} + +} // namespace kai::test diff --git a/test/reference/matmul.hpp b/test/reference/matmul.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7dfca92c033fa8535f072d667b060b9a8f2fd78a --- /dev/null +++ b/test/reference/matmul.hpp @@ -0,0 +1,60 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "test/common/data_type.hpp" + +namespace kai::test { + +class DataFormat; + +/// Packs the RHS operand of matrix multiplication. +/// +/// @param[in] data Data buffer. +/// @param[in] scales (Optional) Quantization scales. +/// @param[in] zero_points (Optional) Quantization zero points. +/// @param[in] src_format Data format of the RHS matrix. +/// @param[in] dst_format Data format of the packed RHS matrix. +/// @param[in] height Number of rows. +/// @param[in] width Number of columns. +/// +/// @return The packed RHS matrix. +std::vector matmul_pack_rhs( + const void* data, const void* scales, const void* zero_points, const DataFormat& src_format, + const DataFormat& dst_format, size_t height, size_t width); + +/// Matrix multiplication. +/// +/// @param[in] lhs LHS operand data. +/// @param[in] lhs_scales (Optional) LHS operand quantization scales. 
+/// @param[in] lhs_zero_points (Optional) LHS operand quantization zero point. +/// @param[in] lhs_dt LHS operand data type. +/// @param[in] rhs RHS operand data. +/// @param[in] rhs_scales (Optional) RHS operand quantization scales. +/// @param[in] rhs_zero_points (Optional) RHS operand quantization zero point. +/// @param[in] rhs_dt RHS operand data type. +/// @param[in] dst_dt Output data type. +/// @param[in] m Output height. +/// @param[in] n Output width. +/// @param[in] k Non-transposed LHS width and non-transposed RHS height. +/// @param[in] lhs_transposed `true` if LHS operand is transposed. +/// @param[in] rhs_transposed `true` if RHS operand is transposed. +/// +/// @return The result data buffer. +std::vector matmul( + const void* lhs, const void* lhs_scales, const void* lhs_zero_points, DataType lhs_dt, // + const void* rhs, const void* rhs_scales, const void* rhs_zero_points, DataType rhs_dt, // + DataType dst_dt, // + size_t m, size_t n, size_t k, // + bool lhs_transposed, bool rhs_transposed); + +} // namespace kai::test diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8a596ed221d77e31391fa2cf025ec4621be2355d --- /dev/null +++ b/test/reference/pack.cpp @@ -0,0 +1,206 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/pack.hpp" + +#include +#include +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/reference/quantize.hpp" + +namespace kai::test { + +namespace { + +/// Packs the matrix from raw to quantized format. 
+template +std::vector pack_quant_per_row( + const void* src, size_t height, size_t width, size_t block_height, size_t block_width) { + const auto num_groups = (height + block_height - 1) / block_height; + const auto group_num_blocks = (width + block_width - 1) / block_width; + + const auto group_zero_points_bytes = block_height * sizeof(ZeroPoint); + const auto group_scales_bytes = block_height * sizeof(Scale); + const auto block_data_bytes = block_height * block_width * sizeof(Output); + const auto group_bytes = group_zero_points_bytes + group_num_blocks * block_data_bytes + group_scales_bytes; + const auto dst_bytes = num_groups * group_bytes; + + std::vector dst; + dst.resize(dst_bytes); + + const auto* src_ptr = reinterpret_cast(src); + auto* dst_ptr = dst.data(); + + std::vector scales; + scales.resize(block_height); + + std::vector zero_points; + zero_points.resize(block_height); + + for (size_t group_no = 0; group_no < num_groups; ++group_no) { + // Finds the range of values and calculates the quantization information. + for (size_t y = 0; y < block_height; ++y) { + auto min_value = std::numeric_limits::max(); + auto max_value = std::numeric_limits::lowest(); + + for (size_t x = 0; x < width; ++x) { + const auto value = src_ptr[(group_no * block_height + y) * width + x]; + + if (value < min_value) { + min_value = value; + } + + if (value > max_value) { + max_value = value; + } + } + + std::tie(scales[y], zero_points[y]) = get_qai8_scale_zero_point_from_range(min_value, max_value); + } + + // Packs the zero points. + memcpy(dst_ptr, zero_points.data(), group_zero_points_bytes); + dst_ptr += group_zero_points_bytes; + + // Quantizes and packs the data. 
+ for (size_t x_block = 0; x_block < group_num_blocks; ++x_block) { + for (size_t block_y = 0; block_y < block_height; ++block_y) { + for (size_t block_x = 0; block_x < block_width; ++block_x) { + const auto value = + src_ptr[(group_no * block_height + block_y) * width + x_block * block_width + block_x]; + const auto qvalue = quantize_i8_fp32(value, scales[block_y], zero_points[block_y]); + *reinterpret_cast(dst_ptr) = qvalue; + ++dst_ptr; + } + } + } + + // Packs the scales. + memcpy(dst_ptr, scales.data(), group_scales_bytes); + dst_ptr += group_scales_bytes; + } + + KAI_ASSERT(reinterpret_cast(dst_ptr) - reinterpret_cast(dst.data()) == dst_bytes); + + return dst; +} + +/// Packs the matrix with per-row quantized format. +/// +/// The source matrix is per-row quantized with separate quantization scale and zero-points data buffer. +/// The destination data is per-row quantized with blocking and embedded quantization information. +std::vector pack_per_row_qs4( + const void* src, const void* scales, const void* zero_points, size_t height, size_t width, size_t block_height, + size_t block_width, size_t subblock_height, size_t subblock_width) { + // Number of elements in a sub-block in vertical and horizontal axes. + const auto num_element_rows = subblock_height; + const auto num_element_cols = subblock_width; + const auto src_element_row_stride = width / 2; + + // Number of sub-blocks in a block in vertical and horizontal axes. + const auto num_subblock_rows = block_height / subblock_height; + const auto num_subblock_cols = block_width / subblock_width; + const auto src_subblock_col_stride = subblock_width / 4; + const auto src_subblock_row_stride = subblock_height * width / 2; + + // Number of blocks in the matrix in vertical and horizontal axes. 
+ const auto num_block_rows = (height + block_height - 1) / block_height; + const auto num_block_cols = (width + block_width - 1) / block_width; + const auto src_block_col_stride = block_width / 2; + const auto src_block_row_stride = block_height * width / 2; + + const auto dst_block_row_scales_bytes = block_height * sizeof(float); + const auto dst_block_row_zero_points_bytes = block_height * sizeof(int32_t); + const auto dst_block_row_data_bytes = num_block_cols * block_height * block_width / 2; + const auto dst_bytes = + num_block_rows * (dst_block_row_zero_points_bytes + dst_block_row_data_bytes + dst_block_row_scales_bytes); + + std::vector dst; + dst.resize(dst_bytes); + + const auto* src_ptr = reinterpret_cast(src); + const auto* scales_ptr = reinterpret_cast(scales); + const auto* zero_points_ptr = reinterpret_cast(zero_points); + auto* dst_ptr = dst.data(); + + for (size_t block_row = 0; block_row < num_block_rows; ++block_row) { + if (zero_points_ptr != nullptr) { + memcpy(dst_ptr, zero_points_ptr + block_row * block_height, dst_block_row_zero_points_bytes); + } + + dst_ptr += dst_block_row_zero_points_bytes; + + for (size_t block_col = 0; block_col < num_block_cols; ++block_col) { + for (size_t subblock_col = 0; subblock_col < num_subblock_cols; ++subblock_col) { + for (size_t subblock_row = 0; subblock_row < num_subblock_rows; ++subblock_row) { + for (size_t element_col = 0; element_col < num_element_cols / 4; ++element_col) { + for (size_t element_row = 0; element_row < num_element_rows; ++element_row) { + const auto byte_lo = src_ptr[ // + block_row * src_block_row_stride + block_col * src_block_col_stride + + subblock_row * src_subblock_row_stride + subblock_col * src_subblock_col_stride + + element_row * src_element_row_stride + element_col]; + const auto byte_hi = src_ptr[ // + block_row * src_block_row_stride + block_col * src_block_col_stride + + subblock_row * src_subblock_row_stride + subblock_col * src_subblock_col_stride + + element_row * 
src_element_row_stride + element_col + block_width / 4]; + + const auto packed_byte0 = (byte_lo & 0x0F) | (byte_hi << 4); + const auto packed_byte1 = (byte_lo >> 4) | (byte_hi & 0xF0); + + dst_ptr[0] = packed_byte0; // ^ 0x88; + dst_ptr[1] = packed_byte1; // ^ 0x88; + dst_ptr += 2; + } + } + } + } + } + + if (scales_ptr != nullptr) { + memcpy(dst_ptr, scales_ptr + block_row * block_height, dst_block_row_scales_bytes); + } + dst_ptr += dst_block_row_scales_bytes; + } + + KAI_ASSERT(reinterpret_cast(dst_ptr) - reinterpret_cast(dst.data()) == dst_bytes); + + return dst; +} + +} // namespace + +std::vector pack( + const DataFormat& dst_format, const void* src, const void* scales, const void* zero_points, + const DataFormat& src_format, size_t height, size_t width) { + const auto dst_dt = dst_format.data_type(); + const auto dst_qf = dst_format.quantization_format(); + const auto src_dt = src_format.data_type(); + const auto src_qf = src_format.quantization_format(); + + if (src_qf == DataFormat::QuantizationFormat::NONE && dst_qf == DataFormat::QuantizationFormat::PER_ROW) { + if (dst_dt == DataType::QAI8 && src_dt == DataType::FP32 && dst_format.scale_data_type() == DataType::FP32 && + dst_format.zero_point_data_type() == DataType::I32) { + return pack_quant_per_row( + src, height, width, dst_format.block_height(), dst_format.block_width()); + } else if ( + dst_dt == DataType::QSI4 && src_dt == DataType::QSU4 && dst_format.scale_data_type() == DataType::FP32 && + dst_format.zero_point_data_type() == DataType::I32) { + return pack_per_row_qs4( + src, scales, zero_points, height, width, dst_format.block_height(), dst_format.block_width(), + dst_format.subblock_height(), dst_format.subblock_width()); + } + } + + KAI_ERROR("Unsupported operation!"); +} + +} // namespace kai::test diff --git a/test/reference/pack.hpp b/test/reference/pack.hpp new file mode 100644 index 0000000000000000000000000000000000000000..43362009f5debff9fa5695ad1f84d8ed2c65852c --- /dev/null +++ 
b/test/reference/pack.hpp @@ -0,0 +1,28 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace kai::test { + +class DataFormat; + +/// Packs the matrix. +/// +/// @param[in] dst_format Data format of the destination matrix. +/// @param[in] src Data buffer of the source matrix. +/// @param[in] src_format Data format of the source matrix. +/// @param[in] height Number of rows of the source matrix. +/// @param[in] width Number of columns of the source matrix. +std::vector pack( + const DataFormat& dst_format, const void* src, const void* scales, const void* zero_points, + const DataFormat& src_format, size_t height, size_t width); + +} // namespace kai::test diff --git a/test/sample.cpp b/test/sample.cpp deleted file mode 100644 index d8e63383f5bf4c19ad8942dd767d848e57570724..0000000000000000000000000000000000000000 --- a/test/sample.cpp +++ /dev/null @@ -1,10 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// -#include - -TEST(SampleSuite, SampleTest) { - EXPECT_EQ(1 + 2, 3); -} diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dcafa8e02d98c19abc996392b1d793954a94cff6 --- /dev/null +++ b/test/tests/matmul_test.cpp @@ -0,0 +1,444 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/matmul.hpp" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/compare.hpp" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/common/matrix_portion.hpp" +#include "test/common/printer.hpp" +#include 
"test/reference/fill.hpp" +#include "test/reference/pack.hpp" + +namespace kai::test { + +// NOLINTBEGIN(misc-non-private-member-variables-in-classes) + +/// Matrix multiplication method. +struct MatMulMethod { + size_t m0; ///< Block size in M dimension. + size_t n0; ///< Block size in N dimension. + size_t k0; ///< Block size in K dimension. + + bool lhs_transposed; ///< LHS matrix is transposed. + bool rhs_transposed; ///< RHS matrix is transposed. + + DataFormat dst_format; ///< Data format of the destination matrix. + DataFormat lhs_format; ///< Data format of the LHS matrix. + DataFormat packed_lhs_format; ///< Data format of the packed LHS matrix. + DataFormat rhs_format; ///< Data format of the RHS matrix. + DataFormat packed_rhs_format; ///< Data format of the packed RHS matrix. + + /// Gets the offset in bytes of the LHS matrix. + /// + /// @param[in] m_idx Coordinate of the matrix in M dimension. + /// @param[in] stride Row stride in bytes. + /// + /// @return The offset in bytes. + std::function fn_get_lhs_offset; + + /// Gets the size in bytes of the packed LHS matrix. + /// + /// @param[in] m Size of the matrix in M dimension. + /// @param[in] k Size of the matrix in K dimension. + /// + /// @return The size in bytes. + std::function fn_get_packed_lhs_size; + + /// Gets the offset in bytes of the packed LHS matrix. + /// + /// @param[in] m_idx Coordinate of the matrix in M dimension. + /// @param[in] k Size of the matrix in K dimension. + /// + /// @return The offset in bytes. + std::function fn_get_packed_lhs_offset; + + /// Preprocesses the LHS matrix. + /// + /// @param[in] m Size of the matrix in M dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] lhs LHS matrix data buffer. + /// @param[in] lhs_row_stride Row stride in bytes of the LHS matrix. + /// @param[out] packed_lhs Packed LHS matrix data buffer. + std::function fn_pack_lhs; + + /// Gets the offset in bytes of the RHS matrix. 
+ /// + /// @param[in] n_idx Coordinate of the matrix in N dimension. + /// @param[in] stride Row stride in bytes. + /// + /// @return The offset in bytes. + std::function fn_get_rhs_offset; + + /// Gets the size in bytes of the packed RHS matrix. + /// + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] block_height Block height. + /// @param[in] block_width Block width. + /// + /// @return The size in bytes. + std::function fn_get_packed_rhs_size; + + /// Gets the offset in bytes of the packed RHS matrix. + /// + /// @param[in] n_idx Coordinate of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] block_height Block height. + /// @param[in] block_width Block width. + /// + /// @return The offset in bytes. + std::function fn_get_packed_rhs_offset; + + /// Performs matrix multiplication. + /// + /// @param[in] m Size of the matrix in M dimension. + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] lhs_p Packed LHS data buffer. + /// @param[in] rhs_p Packed RHS data buffer. + /// @param[out] dst Output data buffer. + /// @param[in] dst_stride_row Output row stride. + /// @param[in] dst_stride_col Output column stride. + /// @param[in] scalar_min Lower bound of the output data. + /// @param[in] scalar_max Upper bound of the output data. + std::function + fn_main; + + /// Gets a value indicating whether pre-processing the RHS matrix is needed. + [[nodiscard]] bool is_pack_rhs_needed() const { + return false; + } + + /// Preprocesses the RHS matrix. + /// + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] rhs RHS data buffer. + /// @param[in] rhs_row_stride RHS row stride. + /// @param[in] bias Bias data buffer. + /// @param[in] scale Quantization scales data buffer. 
+ /// @param[out] packed_rhs Packed RHS data buffer. + void pack_rhs( + size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale, + void* packed_rhs) const { + KAI_UNUSED(n); + KAI_UNUSED(k); + KAI_UNUSED(rhs); + KAI_UNUSED(rhs_row_stride); + KAI_UNUSED(bias); + KAI_UNUSED(scale); + KAI_UNUSED(packed_rhs); + + KAI_ERROR("RHS pre-processing is not supported!"); + } +}; + +// NOLINTEND(misc-non-private-member-variables-in-classes) + +/// List of supported matrix multiplication methods. +static const std::array matmul_methods = { + MatMulMethod{ + .m0 = 4, + .n0 = 4, + .k0 = 32, + + .lhs_transposed = false, + .rhs_transposed = true, + + .dst_format = DataFormat(DataType::FP32), + .lhs_format = DataFormat(DataType::FP32), + .packed_lhs_format = + DataFormat(DataType::QAI8, 4, 8, DataFormat::QuantizationFormat::PER_ROW, DataType::FP32, DataType::I32), + .rhs_format = DataFormat(DataType::QSU4), + .packed_rhs_format = DataFormat( + DataType::QSI4, 4, 32, DataFormat::QuantizationFormat::PER_ROW, DataType::FP32, DataType::I32, 1, 16), + + .fn_get_lhs_offset = nullptr, + .fn_get_packed_lhs_size = nullptr, + .fn_get_packed_lhs_offset = nullptr, + .fn_pack_lhs = nullptr, + + .fn_get_rhs_offset = nullptr, + .fn_get_packed_rhs_size = nullptr, + .fn_get_packed_rhs_offset = nullptr, + + .fn_main = nullptr, + }, +}; + +/// Matrix multiplication shape. +struct MatMulShape { + size_t m; ///< LHS height. + size_t n; ///< RHS width. + size_t k; ///< LHS width and RHS height. +}; + +/// Matrix multiplication test information. +using MatMulTestParams = std::tuple; + +/// Prints the test information. 
+void PrintTo(const MatMulTestParams& param, std::ostream* os) { + const auto& [shape, method_no, portion] = param; + + *os << "m: " << shape.m << ", n: " << shape.n << ", k: " << shape.k << ", method_no: " << method_no + << ", portion: { start_row: " << portion.start_row() << ", start_col: " << portion.start_col() + << ", height: " << portion.height() << ", width: " << portion.width() << "}"; +} + +/// Matrix multiplication test fixture. +class MatMulTest : public testing::TestWithParam { +private: + /// Unique ID: m, n, k, method_id. + using TestDataId = std::tuple; + +protected: + /// Cached test data that is shared between multiple test case. + struct TestData { + std::vector lhs{}; ///< LHS operand. + std::vector ref_packed_lhs{}; ///< Reference packed LHS. + std::vector rhs{}; ///< RHS operand. + std::vector rhs_scales{}; ///< RHS per-row quantization scales. + std::vector ref_packed_rhs{}; ///< Reference packed RHS. + std::vector ref_dst{}; ///< Reference output. + }; + + /// Gets the test data for the current test case. + static const TestData& test_data() { + const auto& [info, method_no, portion] = GetParam(); + const TestDataId data_id{info.m, info.n, info.k, method_no}; + + // If the test data is already available, returns it. + const auto data_it = _data.find(data_id); + + if (data_it != _data.end()) { + return data_it->second; + } + + // Generates the test data. + const auto& method = matmul_methods.at(method_no); + + const auto lhs_h = method.lhs_transposed ? info.k : info.m; + const auto lhs_w = method.lhs_transposed ? info.m : info.k; + auto lhs = fill_matrix_random(lhs_h, lhs_w, method.lhs_format, 0); + auto ref_packed_lhs = + pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w); + + const auto rhs_h = method.rhs_transposed ? info.n : info.k; + const auto rhs_w = method.rhs_transposed ? 
info.k : info.n; + auto rhs = fill_matrix_random(rhs_h, rhs_w, method.rhs_format, 1); + + std::vector rhs_scales; + if (data_type_is_quantized(method.rhs_format.data_type()) && + method.rhs_format.quantization_format() == DataFormat::QuantizationFormat::NONE) { + rhs_scales = fill_matrix_random(rhs_h, 1, DataFormat(DataType::FP32), 2); + } + + auto packed_rhs = matmul_pack_rhs( + rhs.data(), !rhs_scales.empty() ? rhs_scales.data() : nullptr, nullptr, method.rhs_format, + method.packed_rhs_format, rhs_h, rhs_w); + + KAI_ASSUME(method.lhs_format.is_raw()); + KAI_ASSUME(method.rhs_format.is_raw()); + KAI_ASSUME(method.dst_format.is_raw()); + auto ref_dst = matmul( + lhs.data(), nullptr, nullptr, method.lhs_format.data_type(), // + rhs.data(), rhs_scales.data(), nullptr, method.rhs_format.data_type(), // + method.dst_format.data_type(), // + info.m, info.n, info.k, method.lhs_transposed, method.rhs_transposed); + + const auto& data = _data[data_id] = { + .lhs = std::move(lhs), + .ref_packed_lhs = std::move(ref_packed_lhs), + .rhs = std::move(rhs), + .rhs_scales = std::move(rhs_scales), + .ref_packed_rhs = std::move(packed_rhs), + .ref_dst = std::move(ref_dst), + }; + + return data; + } + +private: + static std::map _data; +}; + +std::map MatMulTest::_data; + +/// Tests the LHS packing kernel. +TEST_P(MatMulTest, PackedLhs) { + const auto& [info, method_no, portion] = GetParam(); + const auto& data = test_data(); + const auto& method = matmul_methods.at(method_no); + + if (method.fn_pack_lhs == nullptr) { + GTEST_SKIP(); + } + + const auto lhs_h = method.lhs_transposed ? info.k : info.m; + const auto lhs_w = method.lhs_transposed ? 
info.m : info.k; + + const auto rect = portion.compute_portion( + lhs_h, lhs_w, method.packed_lhs_format.scheduler_block_height(lhs_h), + method.packed_lhs_format.scheduler_block_width(lhs_w)); + + if (rect.height() == 0 || rect.width() == 0) { + GTEST_SKIP(); + } + + const auto ref_lhs_row_stride = method.lhs_format.default_row_stride(lhs_w); + + const auto packed_lhs_size = method.fn_get_packed_lhs_size(info.m, info.k); + const auto ref_packed_lhs_size = method.packed_lhs_format.default_size_in_bytes(lhs_h, lhs_w); + ASSERT_EQ(packed_lhs_size, ref_packed_lhs_size); + + const auto lhs_offset = method.fn_get_lhs_offset(rect.start_row(), ref_lhs_row_stride); + const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(rect.start_row(), 0, lhs_w); + ASSERT_EQ(lhs_offset, ref_lhs_offset); + + const auto packed_lhs_offset = method.fn_get_packed_lhs_offset(rect.start_row(), info.k); + const auto ref_packed_lhs_offset = method.packed_lhs_format.default_offset_in_bytes(rect.start_row(), 0, lhs_w); + ASSERT_EQ(packed_lhs_offset, ref_packed_lhs_offset); + + std::vector packed_lhs; + packed_lhs.resize(packed_lhs_size); + method.fn_pack_lhs( + rect.height(), rect.width(), data.lhs.data() + lhs_offset, ref_lhs_row_stride, + packed_lhs.data() + packed_lhs_offset); + + DefaultMismatchHandler handler(0, 0.0001, 0, 0.001); + const auto success = + compare(packed_lhs.data(), data.ref_packed_lhs.data(), method.packed_lhs_format, lhs_h, lhs_w, rect, handler); + ASSERT_TRUE(success); +} + +/// Tests the RHS packing kernel. +TEST_P(MatMulTest, PackedRhs) { + const auto& [info, method_no, portion] = GetParam(); + const auto& data = test_data(); + const auto& method = matmul_methods.at(method_no); + + if (!method.is_pack_rhs_needed()) { + GTEST_SKIP(); + } + + const auto rhs_h = method.rhs_transposed ? info.n : info.k; + const auto rhs_w = method.rhs_transposed ? 
info.k : info.n; + + const auto rect = portion.compute_portion( + rhs_h, rhs_w, method.packed_rhs_format.scheduler_block_height(rhs_h), + method.packed_rhs_format.scheduler_block_width(rhs_w)); + + if (rect.height() == 0 || rect.width() == 0) { + GTEST_SKIP(); + } + + const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(rhs_w); + + const auto rhs_offset = method.fn_get_rhs_offset(rect.start_row(), ref_rhs_row_stride); + const auto ref_rhs_offset = method.rhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), rhs_w); + ASSERT_EQ(rhs_offset, ref_rhs_offset); + + const auto packed_rhs_size = method.fn_get_packed_rhs_size( + rhs_h, rhs_w, method.packed_rhs_format.block_height(), method.packed_rhs_format.block_width()); + const auto ref_packed_rhs_size = method.packed_rhs_format.default_size_in_bytes(rhs_h, rhs_w); + ASSERT_EQ(packed_rhs_size, ref_packed_rhs_size); + + const auto packed_rhs_offset = method.fn_get_packed_rhs_offset( + rect.start_row(), rhs_w, method.packed_rhs_format.block_height(), method.packed_rhs_format.block_width()); + const auto ref_packed_rhs_offset = + method.packed_rhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), rhs_w); + ASSERT_EQ(packed_rhs_offset, ref_packed_rhs_offset); + + const auto ref_rhs_scales_offset = + rect.start_row() * data_type_size_in_bits(method.packed_rhs_format.scale_data_type()) / 8; + + std::vector packed_rhs; + packed_rhs.resize(packed_rhs_size); + + method.pack_rhs( + rect.height(), rect.width(), data.rhs.data() + rhs_offset, ref_rhs_row_stride, nullptr, + !data.rhs_scales.empty() ? data.rhs_scales.data() + ref_rhs_scales_offset : nullptr, + packed_rhs.data() + packed_rhs_offset); + + DefaultMismatchHandler handler(0, 0.0001, 0, 0.001); + const auto success = + compare(packed_rhs.data(), data.ref_packed_rhs.data(), method.packed_rhs_format, rhs_h, rhs_w, rect, handler); + ASSERT_TRUE(success); +} + +/// Tests the output. 
+TEST_P(MatMulTest, Output) { + const auto& [info, method_no, portion] = GetParam(); + const auto& data = test_data(); + const auto& method = matmul_methods.at(method_no); + + if (method.fn_main == nullptr) { + GTEST_SKIP(); + } + + const auto rect = portion.compute_portion(info.m, info.n, method.m0, method.n0); + + if (rect.height() == 0 || rect.width() == 0) { + GTEST_SKIP(); + } + + const auto ref_dst_row_stride = method.dst_format.default_row_stride(info.n); + const auto ref_dst_col_stride = data_type_size_in_bits(method.dst_format.data_type()) / 8; + + const auto ref_packed_lhs_offset = method.packed_lhs_format.default_offset_in_bytes( + method.lhs_transposed ? 0 : rect.start_row(), method.lhs_transposed ? rect.start_row() : 0, + method.lhs_transposed ? info.m : info.k); + const auto ref_packed_rhs_offset = method.packed_rhs_format.default_offset_in_bytes( + method.rhs_transposed ? rect.start_col() : 0, method.rhs_transposed ? 0 : rect.start_col(), + method.rhs_transposed ? info.k : info.n); + const auto ref_dst_offset = method.dst_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), info.n); + + std::vector dst; + dst.resize(method.dst_format.default_size_in_bytes(info.m, info.n)); + + method.fn_main( + rect.height(), rect.width(), info.k, data.ref_packed_lhs.data() + ref_packed_lhs_offset, + data.ref_packed_rhs.data() + ref_packed_rhs_offset, reinterpret_cast(dst.data() + ref_dst_offset), + ref_dst_row_stride, ref_dst_col_stride, std::numeric_limits::lowest(), + std::numeric_limits::max()); + + DefaultMismatchHandler handler(0, 0.1, 0, 0.05); + const auto success = compare(dst.data(), data.ref_dst.data(), method.dst_format, info.m, info.n, rect, handler); + ASSERT_TRUE(success); +} + +INSTANTIATE_TEST_SUITE_P( + MatMul, MatMulTest, + testing::Combine( + testing::Values( + MatMulShape{4, 4, 32}, // + MatMulShape{12, 16, 64}), + testing::Range(0, matmul_methods.size()), + testing::Values( + MatrixPortion(0, 0, 1, 1), // Full matrix. 
+ MatrixPortion(0, 0, 0.25, 0.25), // Top-left corner. + MatrixPortion(0.75, 0.75, 1, 1) // Bottom-right corner. + ))); + +} // namespace kai::test