diff --git a/CMakeLists.txt b/CMakeLists.txt index 76e07cfc3b9e0ecb4e16aa29b7618edc61ba30b9..9401f23f9d6a099f6f4413a53e55be2e2bc66858 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -69,11 +69,16 @@ if(KLEIDIAI_BUILD_TESTS) test/common/data_format.cpp test/common/printer.cpp test/common/matrix_portion.cpp + test/common/rect.cpp test/common/compare.cpp + test/common/int4.cpp test/reference/fill.cpp test/reference/pack.cpp test/reference/quantize.cpp test/reference/round.cpp + test/reference/matmul.cpp + test/reference/binary_elementwise.cpp + test/reference/reduce.cpp test/tests/matmul_test.cpp ) diff --git a/src/kai_common.h b/src/kai_common.h index 47462b4e2914e51ee5d719e60d89617463283573..1d8eab9bf2e3c3dc2d4aed45e3aeea6b463a1d38 100644 --- a/src/kai_common.h +++ b/src/kai_common.h @@ -12,10 +12,10 @@ extern "C" { #include #include -#define KAI_ERROR(msg) \ - do { \ - fprintf(stderr, "%s", msg); \ - exit(EXIT_FAILURE); \ +#define KAI_ERROR(msg) \ + do { \ + fprintf(stderr, "%s (%s:%d)", msg, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ } while (0) #define KAI_ASSERT_MSG(cond, msg) \ diff --git a/test/common/compare.cpp b/test/common/compare.cpp index b81ca52b308b546751c68c4fc4c34b9e28b3b920..ddc934b31ecd85ba7bb21bb2fe140ccd1fd71c8c 100644 --- a/test/common/compare.cpp +++ b/test/common/compare.cpp @@ -13,9 +13,11 @@ #include "src/kai_common.h" #include "test/common/data_format.hpp" +#include "test/common/int4.hpp" #include "test/common/logging.hpp" -#include "test/common/matrix_portion.hpp" +#include "test/common/memory.hpp" #include "test/common/printer.hpp" +#include "test/common/rect.hpp" namespace kai::test { @@ -38,30 +40,58 @@ std::tuple calculate_error(T imp, T ref) { return {abs_error, rel_error}; } +/// Compares matrices with per-row quantization. 
+template +bool compare_raw( + const void* imp_data, const void* ref_data, size_t full_height, size_t full_width, const Rect& rect, + MismatchHandler& handler) { + for (size_t y = 0; y < full_height; ++y) { + for (size_t x = 0; x < full_width; ++x) { + const auto in_roi = + y >= rect.start_row() && y < rect.end_row() && x >= rect.start_col() && x < rect.end_col(); + + const auto imp_value = read_array(imp_data, y * full_width + x); + const auto ref_value = in_roi ? read_array(ref_data, y * full_width + x) : 0; + + const auto [abs_err, rel_err] = calculate_error(imp_value, ref_value); + + if (abs_err != 0 || rel_err != 0) { + const auto notifying = !in_roi || handler.handle_data(abs_err, rel_err); + + if (notifying) { + KAI_LOGE("Mismatched data at (", y, ", ", x, "): actual = ", imp_value, ", expected: ", ref_value); + } + } + } + } + + return handler.success(full_height * full_width); +} + /// Compares matrices with per-row quantization. template bool compare_per_row( const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, - const MatrixPortion& portion, MismatchHandler& handler) { + const Rect& rect, MismatchHandler& handler) { const auto block_height = format.block_height(); const auto block_width = format.block_width(); - const auto [rect_y, rect_x, rect_height, rect_width] = - portion.compute_portion(full_height, full_width, block_height, full_width); - KAI_ASSERT(format.scheduler_block_height(full_height) == block_height); - KAI_ASSERT(format.scheduler_block_width(full_width) == full_width); - KAI_ASSERT(rect_x == 0); - KAI_ASSERT(rect_width == full_width); + KAI_ASSUME(format.scheduler_block_height(full_height) == block_height); + KAI_ASSUME(format.scheduler_block_width(full_width) == full_width); + KAI_ASSUME(rect.start_col() == 0); + KAI_ASSUME(rect.width() == full_width); + + const auto data_bits = size_in_bits(); const auto num_groups = (full_height + block_height - 1) / block_height; const auto 
group_num_blocks = (full_width + block_width - 1) / block_width; const auto group_offsets_bytes = block_height * sizeof(Offset); const auto group_scales_bytes = block_height * sizeof(Scale); - const auto block_data_bytes = block_height * block_width * sizeof(Data); + const auto block_data_bytes = block_height * block_width * data_bits / 8; - const auto begin_group = rect_y / block_height; - const auto end_group = (rect_y + rect_height) / block_height; + const auto begin_group = rect.start_row() / block_height; + const auto end_group = rect.end_row() / block_height; const auto* imp_ptr = reinterpret_cast(imp_data); const auto* ref_ptr = reinterpret_cast(ref_data); @@ -91,8 +121,8 @@ bool compare_per_row( for (size_t block_no = 0; block_no < group_num_blocks; ++block_no) { for (size_t y = 0; y < block_height; ++y) { for (size_t x = 0; x < block_width; ++x) { - const auto imp_data = reinterpret_cast(imp_ptr)[y * block_width + x]; - const Data ref_data = in_roi ? reinterpret_cast(ref_ptr)[y * block_width + x] : 0; + const auto imp_data = read_array(imp_ptr, y * block_width + x); + const Data ref_data = in_roi ? 
read_array(ref_ptr, y * block_width + x) : Data(0); const auto [abs_err, rel_err] = calculate_error(imp_data, ref_data); if (abs_err != 0 || rel_err != 0) { @@ -133,23 +163,35 @@ bool compare_per_row( ref_ptr += group_scales_bytes; } - return handler.success(rect_height * full_width); + return handler.success(rect.height() * full_width); } } // namespace bool compare( const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, - const MatrixPortion& portion, MismatchHandler&& handler) { + const Rect& rect, MismatchHandler&& handler) { const auto data_type = format.data_type(); const auto scale_dt = format.scale_data_type(); const auto offset_dt = format.zero_point_data_type(); switch (format.quantization_format()) { + case DataFormat::QuantizationFormat::NONE: + switch (data_type) { + case DataType::FP32: + return compare_raw(imp_data, ref_data, full_height, full_width, rect, handler); + + default: + break; + } + case DataFormat::QuantizationFormat::PER_ROW: if (data_type == DataType::QI8 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { return compare_per_row( - imp_data, ref_data, format, full_height, full_width, portion, handler); + imp_data, ref_data, format, full_height, full_width, rect, handler); + } else if (data_type == DataType::QSI4 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { + return compare_per_row( + imp_data, ref_data, format, full_height, full_width, rect, handler); } break; @@ -182,15 +224,15 @@ DefaultMismatchHandler::DefaultMismatchHandler(const DefaultMismatchHandler& rhs _num_mismatches(0), _failed(false) { // Cannot copy mismatch handler that is already in use. - KAI_ASSERT(rhs._num_mismatches == 0); - KAI_ASSERT(!rhs._failed); + KAI_ASSUME(rhs._num_mismatches == 0); + KAI_ASSUME(!rhs._failed); } DefaultMismatchHandler& DefaultMismatchHandler::operator=(const DefaultMismatchHandler& rhs) { if (this != &rhs) { // Cannot copy mismatch handler that is already in use. 
- KAI_ASSERT(rhs._num_mismatches == 0); - KAI_ASSERT(!rhs._failed); + KAI_ASSUME(rhs._num_mismatches == 0); + KAI_ASSUME(!rhs._failed); _abs_error_threshold = rhs._abs_error_threshold; _rel_error_threshold = rhs._rel_error_threshold; @@ -221,7 +263,7 @@ bool DefaultMismatchHandler::success(size_t num_checks) const { } float mismatched_rate = static_cast(_num_mismatches) / static_cast(num_checks); - return _num_mismatches <= _abs_mismatched_threshold && mismatched_rate <= _rel_mismatched_threshold; + return _num_mismatches <= _abs_mismatched_threshold || mismatched_rate <= _rel_mismatched_threshold; } } // namespace kai::test diff --git a/test/common/compare.hpp b/test/common/compare.hpp index 04ac829ed368549972d5f045f4b514479372e86c..f740e46a78a410d35228ade9b813b8c65a7e27e9 100644 --- a/test/common/compare.hpp +++ b/test/common/compare.hpp @@ -11,7 +11,7 @@ namespace kai::test { class DataFormat; -class MatrixPortion; +class Rect; class MismatchHandler; /// Compares two matrices to check whether they are matched. @@ -21,13 +21,13 @@ class MismatchHandler; /// @param[in] format Data format. /// @param[in] full_height Height of the full matrix. /// @param[in] full_width Width of the full matrix. -/// @param[in] portion Portion of the matrix to be calculated in the actual implementation matrix. +/// @param[in] rect Rectangular region of the matrix that is populated with data. /// @param[in] handler Mismatch handler. /// /// @return `true` if the two matrices are considered matched. bool compare( const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, - const MatrixPortion& portion, MismatchHandler&& handler); + const Rect& rect, MismatchHandler&& handler); /// Handles mismatches found by @ref validate function. 
class MismatchHandler { diff --git a/test/common/data_format.cpp b/test/common/data_format.cpp index 9310ad933d83cf5f049028c5018db054d761840d..5506cc0f7ca93682aafeb53cbd09f5200c26c7df 100644 --- a/test/common/data_format.cpp +++ b/test/common/data_format.cpp @@ -16,13 +16,15 @@ namespace kai::test { DataFormat::DataFormat( DataType data_type, size_t block_height, size_t block_width, QuantizationFormat quant_format, DataType scale_dt, - DataType zero_point_dt) noexcept : + DataType zero_point_dt, size_t subblock_height, size_t subblock_width) noexcept : _data_type(data_type), _quant_format(quant_format), _scale_dt(scale_dt), _zero_point_dt(zero_point_dt), _block_height(block_height), - _block_width(block_width) { + _block_width(block_width), + _subblock_height(subblock_height), + _subblock_width(subblock_width) { } bool DataFormat::operator==(const DataFormat& rhs) const { @@ -50,6 +52,11 @@ DataType DataFormat::zero_point_data_type() const { return _zero_point_dt; } +bool DataFormat::is_raw() const { + return _quant_format == QuantizationFormat::NONE && // + _block_height == 0 && _block_width == 0 && _subblock_height == 0 && _subblock_width == 0; +} + size_t DataFormat::block_height() const { return _block_height; } @@ -58,9 +65,19 @@ size_t DataFormat::block_width() const { return _block_width; } +size_t DataFormat::subblock_height() const { + return _subblock_height; +} + +size_t DataFormat::subblock_width() const { + return _subblock_width; +} + size_t DataFormat::scheduler_block_height([[maybe_unused]] size_t full_height) const { switch (_quant_format) { case QuantizationFormat::NONE: + return _block_height > 0 ? _block_height : 1; + case QuantizationFormat::PER_ROW: return _block_height; @@ -72,7 +89,7 @@ size_t DataFormat::scheduler_block_height([[maybe_unused]] size_t full_height) c size_t DataFormat::scheduler_block_width(size_t full_width) const { switch (_quant_format) { case QuantizationFormat::NONE: - return _block_width; + return _block_width > 0 ? 
_block_width : 1; case QuantizationFormat::PER_ROW: return full_width; @@ -83,7 +100,7 @@ size_t DataFormat::scheduler_block_width(size_t full_width) const { } uintptr_t DataFormat::default_row_stride(size_t width) const { - const auto padded_width = round_up_multiple(width, _block_width); + const auto padded_width = round_up_multiple(width, _block_width > 0 ? _block_width : 1); switch (_quant_format) { case QuantizationFormat::NONE: @@ -106,12 +123,12 @@ uintptr_t DataFormat::default_offset_in_bytes(size_t row, size_t col, size_t wid switch (_quant_format) { case QuantizationFormat::NONE: - return row * row_stride + col; + return row * row_stride + col * data_type_size_in_bits(_data_type) / 8; case QuantizationFormat::PER_ROW: KAI_ASSERT(row % _block_height == 0); KAI_ASSERT(col == 0); - return (row / _block_height) * row_stride + col; + return (row / _block_height) * row_stride + col * data_type_size_in_bits(_data_type) / 8; default: KAI_ERROR("Unsupported quantization packing format!"); @@ -119,7 +136,7 @@ uintptr_t DataFormat::default_offset_in_bytes(size_t row, size_t col, size_t wid } size_t DataFormat::default_size_in_bytes(size_t height, size_t width) const { - const auto num_rows = (height + _block_height - 1) / _block_height; + const auto num_rows = _block_height > 0 ? (height + _block_height - 1) / _block_height : height; return num_rows * default_row_stride(width); } diff --git a/test/common/data_format.hpp b/test/common/data_format.hpp index 24a17574705a633ad856bc08ffccb02074e18b21..3aa19f71e1dec3c8b6699825694accb418aa09a2 100644 --- a/test/common/data_format.hpp +++ b/test/common/data_format.hpp @@ -13,6 +13,7 @@ namespace kai::test { +/// Data format. class DataFormat { public: /// Quantization packing format. @@ -24,15 +25,17 @@ public: /// Creates a new data format. /// /// @param[in] data_type Data type of data value. + /// @param[in] block_height Block height. + /// @param[in] block_width Block width. 
/// @param[in] quant_format Quantization packing format. /// @param[in] scale_dt Data type of scale value. /// @param[in] zero_point_dt Data type of zero point value. - /// @param[in] block_height Block height. - /// @param[in] block_width Block width. + /// @param[in] subblock_height Sub-block height. + /// @param[in] subblock_width Sub-block width. DataFormat( - DataType data_type, size_t block_height = 1, size_t block_width = 1, + DataType data_type, size_t block_height = 0, size_t block_width = 0, QuantizationFormat quant_format = QuantizationFormat::NONE, DataType scale_dt = DataType::UNKNOWN, - DataType zero_point_dt = DataType::UNKNOWN) noexcept; + DataType zero_point_dt = DataType::UNKNOWN, size_t subblock_height = 0, size_t subblock_width = 0) noexcept; /// Equality operator. [[nodiscard]] bool operator==(const DataFormat& rhs) const; @@ -52,12 +55,21 @@ public: /// Gets the data type of zero point value. [[nodiscard]] DataType zero_point_data_type() const; + /// Gets a value indicating whether this format has no blocking or packed quantization information. + [[nodiscard]] bool is_raw() const; + /// Gets the block height. [[nodiscard]] size_t block_height() const; /// Gets the block width. [[nodiscard]] size_t block_width() const; + /// Gets the sub-block height. + [[nodiscard]] size_t subblock_height() const; + + /// Gets the sub-block width. + [[nodiscard]] size_t subblock_width() const; + /// Gets the scheduling block height. /// /// @param[in] full_height Height of the full matrix. 
@@ -107,6 +119,8 @@ private: DataType _zero_point_dt; size_t _block_height; size_t _block_width; + size_t _subblock_height; + size_t _subblock_width; }; } // namespace kai::test diff --git a/test/common/data_type.cpp b/test/common/data_type.cpp index 2e7f7f389175c79eefef3cdc93cd62f5706cff36..a23ced1bc71fc58165e123684dd88aaaa6a5a832 100644 --- a/test/common/data_type.cpp +++ b/test/common/data_type.cpp @@ -69,12 +69,12 @@ bool data_type_is_signed(DataType dt) { } bool data_type_is_quantized(DataType dt) { - KAI_ASSERT(data_type_is_integral(dt)); + KAI_ASSERT_IF(has_q(dt), data_type_is_integral(dt)); return has_q(dt); } bool data_type_is_quantized_asymm(DataType dt) { - KAI_ASSERT(data_type_is_quantized(dt)); + KAI_ASSERT_IF(has_a(dt), data_type_is_quantized(dt)); return has_a(dt); } diff --git a/test/common/data_type.hpp b/test/common/data_type.hpp index 17aeedab0cf1452127bc8a9cf1b025af7aa4f7ae..8fbdcf9de24587a5a3da3b51a34758a184cb2032 100644 --- a/test/common/data_type.hpp +++ b/test/common/data_type.hpp @@ -42,6 +42,9 @@ enum class DataType : uint16_t { I32 = 0b1'1'0'0'0000'00100000, ///< 32-bit signed integer. QI8 = 0b1'1'1'1'0000'00001000, ///< 8-bit asymmetric quantized. + + QSU4 = 0b1'0'1'0'0000'00000100, ///< 4-bit unsigned symmetric quantized. + QSI4 = 0b1'1'1'0'0000'00000100, ///< 4-bit signed symmetric quantized. }; /// Gets the size in bits of the specified data type. 
diff --git a/test/common/int4.cpp b/test/common/int4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..07be1836c047f5c5aef7c17ca7530c83912cff47 --- /dev/null +++ b/test/common/int4.cpp @@ -0,0 +1,95 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/int4.hpp" + +#include + +namespace kai::test { + +UInt4& UInt4::operator=(uint8_t value) { + _value = value; + return *this; +} + +UInt4::operator int32_t() const { + return _value; +} + +UInt4::operator float() const { + return _value; +} + +UInt4 UInt4::operator+(UInt4 rhs) const { + return UInt4(_value + rhs._value); +} + +UInt4 UInt4::operator-(UInt4 rhs) const { + return UInt4(_value - rhs._value); +} + +UInt4 UInt4::operator*(UInt4 rhs) const { + return UInt4(_value * rhs._value); +} + +UInt4 UInt4::operator/(UInt4 rhs) const { + return UInt4(_value / rhs._value); +} + +uint8_t UInt4::pack_u8(UInt4 low, UInt4 high) { + return (low._value & 0x0F) | (high._value << 4); +} + +std::tuple UInt4::unpack_u8(uint8_t value) { + const uint8_t low = value & 0x0F; + const uint8_t high = value >> 4; + + return {UInt4(low), UInt4(high)}; +} + +// ===================================================================================================================== + +Int4& Int4::operator=(int8_t value) { + _value = value; + return *this; +} + +Int4::operator int32_t() const { + return _value; +} + +Int4::operator float() const { + return _value; +} + +Int4 Int4::operator+(Int4 rhs) const { + return Int4(_value + rhs._value); +} + +Int4 Int4::operator-(Int4 rhs) const { + return Int4(_value - rhs._value); +} + +Int4 Int4::operator*(Int4 rhs) const { + return Int4(_value * rhs._value); +} + +Int4 Int4::operator/(Int4 rhs) const { + return Int4(_value / rhs._value); +} + +uint8_t Int4::pack_u8(Int4 low, Int4 high) { + return (low._value & 0x0F) | (high._value << 4); +} + +std::tuple Int4::unpack_u8(uint8_t 
value) { + const int8_t low = static_cast(value << 4) >> 4; + const int8_t high = static_cast(value) >> 4; + + return {Int4(low), Int4(high)}; +} + +} // namespace kai::test diff --git a/test/common/int4.hpp b/test/common/int4.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7fbbb1afee341a830bc1d5cc95099dbdef830c11 --- /dev/null +++ b/test/common/int4.hpp @@ -0,0 +1,112 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace kai::test { + +/// 4-bit unsigned integer. +class UInt4 { +public: + /// Creates a new 4-bit unsigned integer value. + /// + /// @param[in] value Value. + constexpr explicit UInt4(uint8_t value) : _value(value) { + } + + /// Assignment operator. + UInt4& operator=(uint8_t value); + + /// Conversion operator. + operator int32_t() const; + + /// Conversion operator. + operator float() const; + + /// Addition operator. + [[nodiscard]] UInt4 operator+(UInt4 rhs) const; + + /// Subtraction operator. + [[nodiscard]] UInt4 operator-(UInt4 rhs) const; + + /// Multiplication operator. + [[nodiscard]] UInt4 operator*(UInt4 rhs) const; + + /// Division operator. + [[nodiscard]] UInt4 operator/(UInt4 rhs) const; + + /// Packs two 4-bit unsigned integer values into one byte. + /// + /// @param[in] low Low nibble. + /// @param[in] high High nibble. + /// + /// @return The packed byte. + [[nodiscard]] static uint8_t pack_u8(UInt4 low, UInt4 high); + + /// Unpacks one byte to two 4-bit unsigned integer values. + /// + /// @param[in] value 8-bit packed value. + /// + /// @return The low and high nibbles. + [[nodiscard]] static std::tuple unpack_u8(uint8_t value); + +private: + uint8_t _value; +}; + +/// 4-bit signed integer. +class Int4 { +public: + /// Creates a new 4-bit signed integer value. + /// + /// @param[in] value Value. 
+ constexpr explicit Int4(int8_t value) : _value(value) { + } + + /// Assignment operator. + Int4& operator=(int8_t value); + + /// Conversion operator. + operator int32_t() const; + + /// Conversion operator. + operator float() const; + + /// Addition operator. + [[nodiscard]] Int4 operator+(Int4 rhs) const; + + /// Subtraction operator. + [[nodiscard]] Int4 operator-(Int4 rhs) const; + + /// Multiplication operator. + [[nodiscard]] Int4 operator*(Int4 rhs) const; + + /// Division operator. + [[nodiscard]] Int4 operator/(Int4 rhs) const; + + /// Packs two 4-bit signed integer values into one byte. + /// + /// @param[in] low Low nibble. + /// @param[in] high High nibble. + /// + /// @return The packed byte. + [[nodiscard]] static uint8_t pack_u8(Int4 low, Int4 high); + + /// Unpacks one byte to two 4-bit signed integer values. + /// + /// @param[in] value 8-bit packed value. + /// + /// @return The low and high nibbles. + [[nodiscard]] static std::tuple unpack_u8(uint8_t value); + +private: + int8_t _value; +}; + +} // namespace kai::test diff --git a/test/common/logging.hpp b/test/common/logging.hpp index ae1f11c03c7021b83e1fc3ec8cd1b682466d558e..ff6130a03a644431a1270588453f8ae478de0a2f 100644 --- a/test/common/logging.hpp +++ b/test/common/logging.hpp @@ -6,28 +6,63 @@ #pragma once +#include #include #include +#include + +#include "test/common/int4.hpp" #define KAI_LOGE(...) kai::test::detail::log("ERROR", __VA_ARGS__) namespace kai::test::detail { +/// Prints the specified value to standard error. +/// +/// @tparam T Data type. +/// +/// @param[in] value Value to be printed out. 
template void write_log_content(T&& value) { - std::cerr << value; + using TT = std::decay_t; + + if constexpr (std::is_same_v) { + std::cerr << static_cast(value); + } else if constexpr (std::is_same_v) { + std::cerr << static_cast(value); + } else if constexpr (std::is_same_v) { + std::cerr << static_cast(value); + } else if constexpr (std::is_same_v) { + std::cerr << static_cast(value); + } else { + std::cerr << value; + } } +/// Prints the specified values to standard error. +/// +/// @tparam T Data type of the first value. +/// @tparam Ts Data types of the subsequent values. +/// +/// @param[in] value First value to be printed out. +/// @param[in] others Subsequent values to be printed out. template void write_log_content(T&& value, Ts&&... others) { write_log_content(std::forward(value)); write_log_content(std::forward(others)...); } +/// Prints the log to standard error. +/// +/// @tparam Ts Data types of values to be printed out. +/// +/// @param[in] level Severity level. +/// @param[in] args Values to be printed out. template void log(std::string_view level, Ts&&... 
args) { std::cerr << level << " | "; write_log_content(std::forward(args)...); + std::cerr << "\n"; } } // namespace kai::test::detail diff --git a/test/common/matrix_portion.cpp b/test/common/matrix_portion.cpp index a6844ed313ea6a0d92ef0490b13850a01943cfa7..48567676722d119cd17cfda93304717d980353c8 100644 --- a/test/common/matrix_portion.cpp +++ b/test/common/matrix_portion.cpp @@ -8,8 +8,8 @@ #include #include -#include +#include "test/common/rect.hpp" #include "test/reference/round.hpp" namespace kai::test { @@ -18,7 +18,23 @@ MatrixPortion::MatrixPortion(float start_row, float start_col, float height, flo _start_row(start_row), _start_col(start_col), _height(height), _width(width) { } -std::tuple MatrixPortion::compute_portion( +float MatrixPortion::start_row() const { + return _start_row; +} + +float MatrixPortion::start_col() const { + return _start_col; +} + +float MatrixPortion::height() const { + return _height; +} + +float MatrixPortion::width() const { + return _width; +} + +Rect MatrixPortion::compute_portion( size_t full_height, size_t full_width, size_t scheduler_block_height, size_t scheduler_block_width) const { const auto start_row_f = std::clamp(_start_row, 0, 1); const auto start_col_f = std::clamp(_start_col, 0, 1); diff --git a/test/common/matrix_portion.hpp b/test/common/matrix_portion.hpp index 3b38220a2236a3e2700eb3ef8ce6e5d589cca371..02f9f5481988ac5ac81f382f7e1faaa3794fcd8b 100644 --- a/test/common/matrix_portion.hpp +++ b/test/common/matrix_portion.hpp @@ -7,13 +7,16 @@ #pragma once #include -#include + +#include "test/common/rect.hpp" namespace kai::test { /// Portion of a matrix. /// /// This class is used to define the sub-matrix a test is running and checking. +/// +/// This is the relative version of @ref Rect. class MatrixPortion { public: /// Creates a new matrix portion. @@ -24,6 +27,18 @@ public: /// @param[in] width Portion width as the ratio to the width of the matrix. 
MatrixPortion(float start_row, float start_col, float height, float width); + /// Gets the starting row as the ratio to the height of the matrix. + [[nodiscard]] float start_row() const; + + /// Gets the starting column as the ratio to the width of the matrix. + [[nodiscard]] float start_col() const; + + /// Gets the portion height as the ratio to the height of the matrix. + [[nodiscard]] float height() const; + + /// Gets the portion width as the ratio to the width of the matrix. + [[nodiscard]] float width() const; + /// Computes the starting coordinate and the shape of the sub-matrix. /// /// Requirements: @@ -39,8 +54,8 @@ public: /// @param[in] scheduler_block_height Block height for scheduling purpose. /// @param[in] scheduler_block_width Block width for scheduling purpose. /// - /// @return The starting row, starting column, height and width of the sub-matrix. - [[nodiscard]] std::tuple compute_portion( + /// @return The rectangular region of the matrix. + [[nodiscard]] Rect compute_portion( size_t full_height, size_t full_width, size_t scheduler_block_height, size_t scheduler_block_width) const; private: diff --git a/test/common/memory.hpp b/test/common/memory.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9091fc6813a693637dbf1b0e32df73c4d2992138 --- /dev/null +++ b/test/common/memory.hpp @@ -0,0 +1,83 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "test/common/int4.hpp" + +namespace kai::test { + +/// Gets the size in bits of type `T`. +template +constexpr size_t size_in_bits() { + return sizeof(T) * 8; +} + +/// Gets the size in bits of type `T`. +template <> +constexpr size_t size_in_bits() { + return 4; +} + +/// Gets the size in bits of type `T`. +template <> +constexpr size_t size_in_bits() { + return 4; +} + +/// Reads the array at the specified index. +/// +/// @param[in] array Data buffer. 
+/// @param[in] index Array index. +/// +/// @return The array value at the specified index. +template +T read_array(const void* array, size_t index) { + if constexpr (std::is_same_v) { + const auto [lo, hi] = UInt4::unpack_u8(reinterpret_cast(array)[index / 2]); + return index % 2 == 0 ? lo : hi; + } else if constexpr (std::is_same_v) { + const auto [lo, hi] = Int4::unpack_u8(reinterpret_cast(array)[index / 2]); + return index % 2 == 0 ? lo : hi; + } else { + return reinterpret_cast(array)[index]; + } +} + +/// Writes the specified value to the array. +/// +/// @param[in] array Data buffer. +/// @param[in] index Array index. +/// @param[in] value Value to be stored. +template +void write_array(void* array, size_t index, T value) { + if constexpr (std::is_same_v) { + auto* arr_value = reinterpret_cast(array) + index / 2; + const auto [lo, hi] = UInt4::unpack_u8(*arr_value); + + if (index % 2 == 0) { + *arr_value = UInt4::pack_u8(value, hi); + } else { + *arr_value = UInt4::pack_u8(lo, value); + } + } else if constexpr (std::is_same_v) { + auto* arr_value = reinterpret_cast(array) + index / 2; + const auto [lo, hi] = Int4::unpack_u8(*arr_value); + + if (index % 2 == 0) { + *arr_value = Int4::pack_u8(value, hi); + } else { + *arr_value = Int4::pack_u8(lo, value); + } + } else { + reinterpret_cast(array)[index] = value; + } +} + +} // namespace kai::test diff --git a/test/common/numeric_limits.hpp b/test/common/numeric_limits.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e59508108f5c1751c534c31c520c622c49e5df61 --- /dev/null +++ b/test/common/numeric_limits.hpp @@ -0,0 +1,39 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "test/common/int4.hpp" + +namespace kai::test { + +/// Highest finite value of type `T`. 
+template +inline constexpr T numeric_highest = std::numeric_limits::max(); + +/// Highest finite value of type `T`. +template <> +inline constexpr UInt4 numeric_highest{15}; + +/// Highest finite value of type `T`. +template <> +inline constexpr Int4 numeric_highest{7}; + +/// Lowest finite value of type `T`. +template +inline constexpr T numeric_lowest = std::numeric_limits::lowest(); + +/// Lowest finite value of type `T`. +template <> +inline constexpr UInt4 numeric_lowest{0}; + +/// Lowest finite value of type `T`. +template <> +inline constexpr Int4 numeric_lowest{-8}; + +} // namespace kai::test diff --git a/test/common/printer.cpp b/test/common/printer.cpp index 256ea970a7ce898ea630d07f43e66093010e2a10..368ae7db8299919025b500d24ac86f6e7584f579 100644 --- a/test/common/printer.cpp +++ b/test/common/printer.cpp @@ -12,31 +12,44 @@ #include "src/kai_common.h" #include "test/common/data_format.hpp" #include "test/common/data_type.hpp" +#include "test/common/int4.hpp" namespace kai::test { namespace { inline void print_data(std::ostream& os, const uint8_t* data, size_t len, DataType data_type) { - for (size_t i = 0; i < len; ++i) { - switch (data_type) { - case DataType::FP32: - os << reinterpret_cast(data)[i]; - break; - - case DataType::I32: - os << reinterpret_cast(data)[i]; - break; - - case DataType::QI8: - os << static_cast(reinterpret_cast(data)[i]); - break; - - default: - KAI_ERROR("Unsupported data type!"); + if (data_type == DataType::QSU4) { + for (size_t i = 0; i < len / 2; ++i) { + const auto [low, high] = UInt4::unpack_u8(data[i]); + os << static_cast(low) << ", " << static_cast(high) << ", "; + } + } else if (data_type == DataType::QSI4) { + for (size_t i = 0; i < len / 2; ++i) { + const auto [low, high] = Int4::unpack_u8(data[i]); + os << static_cast(low) << ", " << static_cast(high) << ", "; + } + } else { + for (size_t i = 0; i < len; ++i) { + switch (data_type) { + case DataType::FP32: + os << reinterpret_cast(data)[i]; + break; + + case 
DataType::I32: + os << reinterpret_cast(data)[i]; + break; + + case DataType::QI8: + os << static_cast(reinterpret_cast(data)[i]); + break; + + default: + KAI_ERROR("Unsupported data type!"); + } + + os << ", "; } - - os << ", "; } } @@ -47,7 +60,7 @@ void print_matrix_raw(std::ostream& os, const uint8_t* data, DataType data_type, for (size_t y = 0; y < height; ++y) { os << " ["; print_data(os, data + y * row_stride, width, data_type); - os << "]\n"; + os << "],\n"; } os << "]\n"; } diff --git a/test/common/rect.cpp b/test/common/rect.cpp new file mode 100644 index 0000000000000000000000000000000000000000..457ac2f3573021a6ff77076071578bbca1f1ec4e --- /dev/null +++ b/test/common/rect.cpp @@ -0,0 +1,41 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/rect.hpp" + +#include + +namespace kai::test { + +Rect::Rect(size_t start_row, size_t start_col, size_t height, size_t width) : + _start_row(start_row), _start_col(start_col), _height(height), _width(width) { +} + +size_t Rect::start_row() const { + return _start_row; +} + +size_t Rect::start_col() const { + return _start_col; +} + +size_t Rect::height() const { + return _height; +} + +size_t Rect::width() const { + return _width; +} + +size_t Rect::end_row() const { + return _start_row + _height; +} + +size_t Rect::end_col() const { + return _start_col + _width; +} + +} // namespace kai::test diff --git a/test/common/rect.hpp b/test/common/rect.hpp new file mode 100644 index 0000000000000000000000000000000000000000..92b033b486bdd44d7cb76680f46438c10274e4c5 --- /dev/null +++ b/test/common/rect.hpp @@ -0,0 +1,51 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace kai::test { + +/// Rectangular region of a matrix. +/// +/// This is the absolute version of @ref MatrixPortion. 
+class Rect { +public: + /// Creates a new rectangular region of a matrix. + /// + /// @param[in] start_row Starting row index. + /// @param[in] start_col Starting column index. + /// @param[in] height Number of rows. + /// @param[in] width Number of columns. + Rect(size_t start_row, size_t start_col, size_t height, size_t width); + + /// Gets the starting row index. + [[nodiscard]] size_t start_row() const; + + /// Gets the starting column index. + [[nodiscard]] size_t start_col() const; + + /// Gets the number of rows. + [[nodiscard]] size_t height() const; + + /// Gets the number of columns. + [[nodiscard]] size_t width() const; + + /// Gets the end (exclusive) row index. + [[nodiscard]] size_t end_row() const; + + /// Gets the end (exclusive) column index. + [[nodiscard]] size_t end_col() const; + +private: + size_t _start_row; + size_t _start_col; + size_t _height; + size_t _width; +}; + +} // namespace kai::test diff --git a/test/common/type_traits.hpp b/test/common/type_traits.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2267432a9bdd2f8a4886c689822d4e32fc242467 --- /dev/null +++ b/test/common/type_traits.hpp @@ -0,0 +1,79 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace kai::test { + +class UInt4; +class Int4; + +/// `true` if `T` is unsigned numeric type. +template +inline constexpr bool is_unsigned = std::is_unsigned_v; + +/// `true` if `T` is unsigned numeric type. +template <> +inline constexpr bool is_unsigned = true; + +/// `false` for `Int4`: it is a signed 4-bit integer type, not an unsigned one. +template <> +inline constexpr bool is_unsigned = false; + +/// `true` if `T` is signed numeric type. +template +inline constexpr bool is_signed = std::is_signed_v; + +/// `true` if `T` is signed numeric type. 
+template <> +inline constexpr bool is_signed = false; + +/// `true` for `Int4`: it is a signed 4-bit integer type. +template <> +inline constexpr bool is_signed = true; + +/// `true` if `T` is integral numeric type. +template +inline constexpr bool is_integral = std::is_integral_v; + +/// `true` if `T` is integral numeric type. +template <> +inline constexpr bool is_integral = true; + +/// `true` if `T` is integral numeric type. +template <> +inline constexpr bool is_integral = true; + +/// `true` if `T` is floating-point type. +template +inline constexpr bool is_floating_point = std::is_floating_point_v; + +/// Signed version of type `T`. +template +struct make_signed { + using type = std::make_signed_t; +}; + +/// Signed version of type `T`. +template <> +struct make_signed { + using type = Int4; +}; + +/// Signed version of type `T`. +template <> +struct make_signed { + using type = Int4; +}; + +/// Signed version of type `T`. +template +using make_signed_t = typename make_signed::type; + +} // namespace kai::test diff --git a/test/reference/binary_elementwise.cpp b/test/reference/binary_elementwise.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e37f76411fa8947134229855f1cf2ae7d30f7cc7 --- /dev/null +++ b/test/reference/binary_elementwise.cpp @@ -0,0 +1,149 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/binary_elementwise.hpp" + +#include +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_type.hpp" +#include "test/common/int4.hpp" +#include "test/common/memory.hpp" + +namespace kai::test { + +namespace { + +/// Binary element-wise operator. +enum class BinaryElementwiseOperator : uint32_t { + ADD, ///< Addition. + SUB, ///< Subtraction. + MUL, ///< Multiplication. + DIV, ///< Division. +}; + +/// Scalar binary element-wise function. +/// +/// @tparam op Binary element-wise operator to perform. +/// @tparam T Data type. +/// +/// @param[in] lhs LHS operand. +/// @param[in] rhs RHS operand. 
+/// +/// @return The result of the operation. +template +T scalar_binary_elementwise(T lhs, T rhs) { + if constexpr (op == BinaryElementwiseOperator::ADD) { + return lhs + rhs; + } else if constexpr (op == BinaryElementwiseOperator::SUB) { + return lhs - rhs; + } else if constexpr (op == BinaryElementwiseOperator::MUL) { + return lhs * rhs; + } else if constexpr (op == BinaryElementwiseOperator::DIV) { + return lhs / rhs; + } else { + KAI_ERROR("Unsupported binary element-wise operator!"); + } +} + +/// Binary element-wise function. +/// +/// @tparam op Binary element-wise operator to perform. +/// @tparam T Data type. +/// +/// @param[in] lhs LHS data buffer. +/// @param[in] rhs RHS data buffer. +/// @param[in] lhs_height LHS height. +/// @param[in] lhs_width LHS width. +/// @param[in] rhs_height RHS height. +/// @param[in] rhs_width RHS width. +/// +/// @return The result data buffer. +template +std::vector binary_elementwise_any_op_type( + const void* lhs, const void* rhs, size_t lhs_height, size_t lhs_width, size_t rhs_height, size_t rhs_width) { + const auto height = std::max(lhs_height, rhs_height); + const auto width = std::max(lhs_width, rhs_width); + + std::vector dst; + dst.resize(height * width * size_in_bits() / 8); + KAI_ASSUME(width * size_in_bits() % 8 == 0); + + for (size_t y = 0; y < height; ++y) { + for (size_t x = 0; x < width; ++x) { + const auto lhs_y = lhs_height > 1 ? y : 0; + const auto lhs_x = lhs_width > 1 ? x : 0; + const auto lhs_value = read_array(lhs, lhs_y * lhs_width + lhs_x); + + const auto rhs_y = rhs_height > 1 ? y : 0; + const auto rhs_x = rhs_width > 1 ? 
x : 0; + const auto rhs_value = read_array(rhs, rhs_y * rhs_width + rhs_x); + + const auto dst_value = scalar_binary_elementwise(lhs_value, rhs_value); + write_array(dst.data(), y * width + x, dst_value); + } + } + + return dst; +} + +template +std::vector binary_elementwise_any_type( + const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // + const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) { + KAI_ASSUME(lhs_dt == rhs_dt); + KAI_ASSUME(lhs_height == 1 || rhs_height == 1 || lhs_height == rhs_height); + KAI_ASSUME(lhs_width == 1 || rhs_width == 1 || lhs_width == rhs_width); + + switch (lhs_dt) { + case DataType::FP32: + return binary_elementwise_any_op_type(lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width); + + case DataType::I32: + return binary_elementwise_any_op_type(lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width); + + case DataType::QSU4: + return binary_elementwise_any_op_type(lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width); + + default: + KAI_ERROR("Unsupported data type!"); + } +} + +} // namespace + +std::vector add( + const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // + const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) { + return binary_elementwise_any_type( + lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width); +} + +std::vector sub( + const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // + const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) { + return binary_elementwise_any_type( + lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width); +} + +std::vector mul( + const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // + const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) { + return binary_elementwise_any_type( + lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width); +} + +std::vector div( + const void* lhs, DataType lhs_dt, 
size_t lhs_height, size_t lhs_width, // + const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) { + return binary_elementwise_any_type( + lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width); +} + +} // namespace kai::test diff --git a/test/reference/binary_elementwise.hpp b/test/reference/binary_elementwise.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e3d3a9e1fdc36c7846859ebced2e2a09ae8938a5 --- /dev/null +++ b/test/reference/binary_elementwise.hpp @@ -0,0 +1,89 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +#include "test/common/data_type.hpp" + +namespace kai::test { + +/// Elementwise addition. +/// +/// Broadcasting is supported for any dimension and both LHS and RHS operands. +/// +/// @param[in] lhs LHS data buffer. +/// @param[in] lhs_dt LHS data type. +/// @param[in] lhs_height LHS height. +/// @param[in] lhs_width LHS width. +/// @param[in] rhs RHS data buffer. +/// @param[in] rhs_dt RHS data type. +/// @param[in] rhs_height RHS height. +/// @param[in] rhs_width RHS width. +/// +/// @return The result matrix. +std::vector add( + const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // + const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width); + +/// Elementwise subtraction. +/// +/// Broadcasting is supported for any dimension and both LHS and RHS operands. +/// +/// @param[in] lhs LHS data buffer. +/// @param[in] lhs_dt LHS data type. +/// @param[in] lhs_height LHS height. +/// @param[in] lhs_width LHS width. +/// @param[in] rhs RHS data buffer. +/// @param[in] rhs_dt RHS data type. +/// @param[in] rhs_height RHS height. +/// @param[in] rhs_width RHS width. +/// +/// @return The result matrix. 
+std::vector sub( + const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // + const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width); + +/// Elementwise multiplication. +/// +/// Broadcasting is supported for any dimension and both LHS and RHS operands. +/// +/// @param[in] lhs LHS data buffer. +/// @param[in] lhs_dt LHS data type. +/// @param[in] lhs_height LHS height. +/// @param[in] lhs_width LHS width. +/// @param[in] rhs RHS data buffer. +/// @param[in] rhs_dt RHS data type. +/// @param[in] rhs_height RHS height. +/// @param[in] rhs_width RHS width. +/// +/// @return The result matrix. +std::vector mul( + const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // + const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width); + +/// Elementwise division. +/// +/// Broadcasting is supported for any dimension and both LHS and RHS operands. +/// +/// @param[in] lhs LHS data buffer. +/// @param[in] lhs_dt LHS data type. +/// @param[in] lhs_height LHS height. +/// @param[in] lhs_width LHS width. +/// @param[in] rhs RHS data buffer. +/// @param[in] rhs_dt RHS data type. +/// @param[in] rhs_height RHS height. +/// @param[in] rhs_width RHS width. +/// +/// @return The result matrix. 
+std::vector div( + const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // + const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width); + +} // namespace kai::test diff --git a/test/reference/fill.cpp b/test/reference/fill.cpp index 4d7e326788d9fefd3675455a6a99ed1c860db203..ef49cb0f4cebffebd2f16ee3df79ce49003d2430 100644 --- a/test/reference/fill.cpp +++ b/test/reference/fill.cpp @@ -15,19 +15,25 @@ #include "src/kai_common.h" #include "test/common/data_format.hpp" +#include "test/common/int4.hpp" +#include "test/common/memory.hpp" namespace kai::test { +namespace { + template std::vector fill_matrix_raw(size_t height, size_t width, std::function gen) { - const auto size = height * width * sizeof(T); + const auto size = height * width * size_in_bits() / 8; + KAI_ASSUME(width * size_in_bits() % 8 == 0); + std::vector data; data.resize(size); auto ptr = reinterpret_cast(data.data()); for (size_t y = 0; y < height; ++y) { for (size_t x = 0; x < width; ++x) { - ptr[y * width + x] = gen(y, x); + write_array(ptr, y * width + x, gen(y, x)); } } @@ -35,7 +41,7 @@ std::vector fill_matrix_raw(size_t height, size_t width, std::function< } template -std::vector fill_matrix_random_raw(size_t height, size_t width, [[maybe_unused]] uint64_t seed) { +std::vector fill_matrix_random_raw(size_t height, size_t width, uint64_t seed) { using TDist = std::conditional_t< std::is_floating_point_v, std::uniform_real_distribution, std::uniform_int_distribution>; @@ -45,6 +51,24 @@ std::vector fill_matrix_random_raw(size_t height, size_t width, [[maybe return fill_matrix_raw(height, width, [&](size_t, size_t) { return dist(rnd); }); } +template <> +std::vector fill_matrix_random_raw(size_t height, size_t width, uint64_t seed) { + std::mt19937 rnd(seed); + std::uniform_int_distribution dist(-8, 7); + + return fill_matrix_raw(height, width, [&](size_t, size_t) { return Int4(dist(rnd)); }); +} + +template <> +std::vector fill_matrix_random_raw(size_t height, 
size_t width, uint64_t seed) { + std::mt19937 rnd(seed); + std::uniform_int_distribution dist(0, 15); + + return fill_matrix_raw(height, width, [&](size_t, size_t) { return UInt4(dist(rnd)); }); +} + +} // namespace + std::vector fill_matrix_random(size_t height, size_t width, const DataFormat& format, uint64_t seed) { switch (format.quantization_format()) { case DataFormat::QuantizationFormat::NONE: @@ -52,6 +76,12 @@ std::vector fill_matrix_random(size_t height, size_t width, const DataF case DataType::FP32: return fill_matrix_random_raw(height, width, seed); + case DataType::QSU4: + return fill_matrix_random_raw(height, width, seed); + + case DataType::QSI4: + return fill_matrix_random_raw(height, width, seed); + default: KAI_ERROR("Unsupported data type!"); } diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp new file mode 100644 index 0000000000000000000000000000000000000000..960de3c2dbc60b3fd8e8eebde0755d1fd93b44ad --- /dev/null +++ b/test/reference/matmul.cpp @@ -0,0 +1,154 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/matmul.hpp" + +#include +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/common/int4.hpp" +#include "test/common/memory.hpp" +#include "test/common/printer.hpp" +#include "test/reference/binary_elementwise.hpp" +#include "test/reference/pack.hpp" +#include "test/reference/quantize.hpp" +#include "test/reference/reduce.hpp" + +namespace kai::test { + +namespace { + +/// Matrix multiplication. +/// +/// @tparam T Data type. +/// +/// @param[in] lhs LHS operand data buffer. +/// @param[in] rhs RHS operand data buffer. +/// @param[in] m Output height. +/// @param[in] n Output width. +/// @param[in] k Non-transposed LHS width and non-transposed RHS height. 
+/// @param[in] lhs_transposed `true` if LHS operand is transposed. +/// @param[in] rhs_transposed `true` if RHS operand is transposed. +/// +/// @return The result data buffer. +template +std::vector matmul_any_type( + const void* lhs, const void* rhs, // + size_t m, size_t n, size_t k, // + bool lhs_transposed, bool rhs_transposed) { + const auto lhs_m_stride = lhs_transposed ? 1 : k; + const auto lhs_k_stride = lhs_transposed ? m : 1; + + const auto rhs_n_stride = rhs_transposed ? k : 1; + const auto rhs_k_stride = rhs_transposed ? 1 : n; + + std::vector dst; + dst.resize(m * n * size_in_bits() / 8); + + for (size_t im = 0; im < m; ++im) { + for (size_t in = 0; in < n; ++in) { + T acc{0}; + + for (size_t ik = 0; ik < k; ++ik) { + const auto lhs_value = read_array(lhs, im * lhs_m_stride + ik * lhs_k_stride); + const auto rhs_value = read_array(rhs, in * rhs_n_stride + ik * rhs_k_stride); + acc += lhs_value * rhs_value; + } + + write_array(dst.data(), im * n + in, acc); + } + } + + return dst; +} + +} // namespace + +std::vector matmul_pack_rhs( + const void* data, const void* scales, const void* zero_points, const DataFormat& src_format, + const DataFormat& dst_format, size_t height, size_t width) { + const auto src_dt = src_format.data_type(); + const auto src_qf = src_format.quantization_format(); + + const auto dst_dt = dst_format.data_type(); + const auto dst_qf = dst_format.quantization_format(); + + std::vector tmp_data; + std::vector tmp_scales; + std::vector tmp_zero_points; + + if (src_dt == DataType::QSU4 && src_qf == DataFormat::QuantizationFormat::NONE && // + dst_dt == DataType::QSI4 && dst_qf == DataFormat::QuantizationFormat::PER_ROW) { + // For this specific RHS format conversion: + // + // * 4-bit data is added by 8. + // * Scale is divided by 16. + // * Zero point is accumulation of all values in the same row. 
+ + KAI_ASSUME(zero_points == nullptr); + const int32_t zero_point = 8; + const uint8_t zero_point_i4 = UInt4::pack_u8(UInt4(zero_point), UInt4(zero_point)); + const int32_t row_zero_point = zero_point * width; + + KAI_ASSUME(dst_format.subblock_width() > 0); + const int32_t subblock_width_i32 = dst_format.subblock_width(); + const float subblock_width_f = dst_format.subblock_width(); + + tmp_zero_points = reduce_add(data, src_format, height, width, DataFormat(DataType::I32), 0); + tmp_zero_points = sub(tmp_zero_points.data(), DataType::I32, height, 1, &row_zero_point, DataType::I32, 1, 1); + tmp_zero_points = + mul(tmp_zero_points.data(), DataType::I32, height, 1, &subblock_width_i32, DataType::I32, 1, 1); + zero_points = tmp_zero_points.data(); + + tmp_data = add(data, DataType::QSU4, height, width, &zero_point_i4, DataType::QSU4, 1, 1); + data = tmp_data.data(); + + tmp_scales = div(scales, DataType::FP32, height, 1, &subblock_width_f, DataType::FP32, 1, 1); + scales = tmp_scales.data(); + } + + return pack(dst_format, data, scales, zero_points, src_format, height, width); +} + +std::vector matmul( + const void* lhs, const void* lhs_scales, const void* lhs_zero_points, DataType lhs_dt, // + const void* rhs, const void* rhs_scales, const void* rhs_zero_points, DataType rhs_dt, // + DataType dst_dt, // + size_t m, size_t n, size_t k, // + bool lhs_transposed, bool rhs_transposed) { + const auto lhs_h = lhs_transposed ? k : m; + const auto lhs_w = lhs_transposed ? m : k; + + const auto rhs_h = rhs_transposed ? n : k; + const auto rhs_w = rhs_transposed ? 
k : n; + + std::vector tmp_lhs; + std::vector tmp_rhs; + + if (data_type_is_quantized(lhs_dt)) { + tmp_lhs = dequantize( + lhs, lhs_scales, lhs_zero_points, lhs_dt, DataType::FP32, QuantizationMethod::PER_MATRIX, lhs_h, lhs_w); + lhs = tmp_lhs.data(); + } + + if (data_type_is_quantized(rhs_dt)) { + tmp_rhs = dequantize( + rhs, rhs_scales, rhs_zero_points, rhs_dt, DataType::FP32, QuantizationMethod::PER_ROW, rhs_h, rhs_w); + rhs = tmp_rhs.data(); + } + + KAI_ASSUME(dst_dt == DataType::FP32); + const auto tmp_dst = matmul_any_type(lhs, rhs, m, n, k, lhs_transposed, rhs_transposed); + + return tmp_dst; +} + +} // namespace kai::test diff --git a/test/reference/matmul.hpp b/test/reference/matmul.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7dfca92c033fa8535f072d667b060b9a8f2fd78a --- /dev/null +++ b/test/reference/matmul.hpp @@ -0,0 +1,60 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "test/common/data_type.hpp" + +namespace kai::test { + +class DataFormat; + +/// Packs the RHS operand of matrix multiplication. +/// +/// @param[in] data Data buffer. +/// @param[in] scales (Optional) Quantization scales. +/// @param[in] zero_points (Optional) Quantization zero points. +/// @param[in] src_format Data format of the RHS matrix. +/// @param[in] dst_format Data format of the packed RHS matrix. +/// @param[in] height Number of rows. +/// @param[in] width Number of columns. +/// +/// @return The packed RHS matrix. +std::vector matmul_pack_rhs( + const void* data, const void* scales, const void* zero_points, const DataFormat& src_format, + const DataFormat& dst_format, size_t height, size_t width); + +/// Matrix multiplication. +/// +/// @param[in] lhs LHS operand data. +/// @param[in] lhs_scales (Optional) LHS operand quantization scales. 
+/// @param[in] lhs_zero_points (Optional) LHS operand quantization zero point. +/// @param[in] lhs_dt LHS operand data type. +/// @param[in] rhs RHS operand data. +/// @param[in] rhs_scales (Optional) RHS operand quantization scales. +/// @param[in] rhs_zero_points (Optional) RHS operand quantization zero point. +/// @param[in] rhs_dt RHS operand data type. +/// @param[in] dst_dt Output data type. +/// @param[in] m Output height. +/// @param[in] n Output width. +/// @param[in] k Non-transposed LHS width and non-transposed RHS height. +/// @param[in] lhs_transposed `true` if LHS operand is transposed. +/// @param[in] rhs_transposed `true` if RHS operand is transposed. +/// +/// @return The result data buffer. +std::vector matmul( + const void* lhs, const void* lhs_scales, const void* lhs_zero_points, DataType lhs_dt, // + const void* rhs, const void* rhs_scales, const void* rhs_zero_points, DataType rhs_dt, // + DataType dst_dt, // + size_t m, size_t n, size_t k, // + bool lhs_transposed, bool rhs_transposed); + +} // namespace kai::test diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp index 81c7c6640614d0682bb8dcb4b29f2844bc91f706..9287fbbcedbfa137176ec3574943cdbd231c16b5 100644 --- a/test/reference/pack.cpp +++ b/test/reference/pack.cpp @@ -91,10 +91,93 @@ std::vector pack_quant_per_row( return dst; } +/// Packs the matrix with per-row quantized format. +/// +/// The source matrix is per-row quantized with separate quantization scale and zero-points data buffer. +/// The destination data is per-row quantized with blocking and embedded quantization information. +std::vector pack_per_row_qs4( + const void* src, const void* scales, const void* zero_points, size_t height, size_t width, size_t block_height, + size_t block_width, size_t subblock_height, size_t subblock_width) { + // Number of elements in a sub-block in vertical and horizontal axes. 
+ const auto num_element_rows = subblock_height; + const auto num_element_cols = subblock_width; + const auto src_element_row_stride = width / 2; + + // Number of sub-blocks in a block in vertical and horizontal axes. + const auto num_subblock_rows = block_height / subblock_height; + const auto num_subblock_cols = block_width / subblock_width; + const auto src_subblock_col_stride = subblock_width / 4; + const auto src_subblock_row_stride = subblock_height * width / 2; + + // Number of blocks in the matrix in vertical and horizontal axes. + const auto num_block_rows = (height + block_height - 1) / block_height; + const auto num_block_cols = (width + block_width - 1) / block_width; + const auto src_block_col_stride = block_width / 2; + const auto src_block_row_stride = block_height * width / 2; + + const auto dst_block_row_scales_bytes = block_height * sizeof(float); + const auto dst_block_row_zero_points_bytes = block_height * sizeof(int32_t); + const auto dst_block_row_data_bytes = num_block_cols * block_height * block_width / 2; + const auto dst_bytes = + num_block_rows * (dst_block_row_zero_points_bytes + dst_block_row_data_bytes + dst_block_row_scales_bytes); + + std::vector dst; + dst.resize(dst_bytes); + + const auto* src_ptr = reinterpret_cast(src); + const auto* scales_ptr = reinterpret_cast(scales); + const auto* zero_points_ptr = reinterpret_cast(zero_points); + auto* dst_ptr = dst.data(); + + for (size_t block_row = 0; block_row < num_block_rows; ++block_row) { + if (zero_points_ptr != nullptr) { + memcpy(dst_ptr, zero_points_ptr + block_row * block_height, dst_block_row_zero_points_bytes); + } + + dst_ptr += dst_block_row_zero_points_bytes; + + for (size_t block_col = 0; block_col < num_block_cols; ++block_col) { + for (size_t subblock_col = 0; subblock_col < num_subblock_cols; ++subblock_col) { + for (size_t subblock_row = 0; subblock_row < num_subblock_rows; ++subblock_row) { + for (size_t element_col = 0; element_col < num_element_cols / 4; 
++element_col) { + for (size_t element_row = 0; element_row < num_element_rows; ++element_row) { + const auto byte_lo = src_ptr[ // + block_row * src_block_row_stride + block_col * src_block_col_stride + + subblock_row * src_subblock_row_stride + subblock_col * src_subblock_col_stride + + element_row * src_element_row_stride + element_col]; + const auto byte_hi = src_ptr[ // + block_row * src_block_row_stride + block_col * src_block_col_stride + + subblock_row * src_subblock_row_stride + subblock_col * src_subblock_col_stride + + element_row * src_element_row_stride + element_col + block_width / 4]; + + const auto packed_byte0 = (byte_lo & 0x0F) | (byte_hi << 4); + const auto packed_byte1 = (byte_lo >> 4) | (byte_hi & 0xF0); + + dst_ptr[0] = packed_byte0; // ^ 0x88; + dst_ptr[1] = packed_byte1; // ^ 0x88; + dst_ptr += 2; + } + } + } + } + } + + if (scales_ptr != nullptr) { + memcpy(dst_ptr, scales_ptr + block_row * block_height, dst_block_row_scales_bytes); + } + dst_ptr += dst_block_row_scales_bytes; + } + + KAI_ASSERT(reinterpret_cast(dst_ptr) - reinterpret_cast(dst.data()) == dst_bytes); + + return dst; +} + } // namespace std::vector pack( - const DataFormat& dst_format, const void* src, const DataFormat& src_format, size_t height, size_t width) { + const DataFormat& dst_format, const void* src, const void* scales, const void* zero_points, + const DataFormat& src_format, size_t height, size_t width) { const auto dst_dt = dst_format.data_type(); const auto dst_qf = dst_format.quantization_format(); const auto src_dt = src_format.data_type(); @@ -105,6 +188,12 @@ std::vector pack( dst_format.zero_point_data_type() == DataType::I32) { return pack_quant_per_row( src, height, width, dst_format.block_height(), dst_format.block_width()); + } else if ( + dst_dt == DataType::QSI4 && src_dt == DataType::QSU4 && dst_format.scale_data_type() == DataType::FP32 && + dst_format.zero_point_data_type() == DataType::I32) { + return pack_per_row_qs4( + src, scales, zero_points, 
height, width, dst_format.block_height(), dst_format.block_width(), + dst_format.subblock_height(), dst_format.subblock_width()); } } diff --git a/test/reference/pack.hpp b/test/reference/pack.hpp index cfd2540d80f7ecc282e2de46f61ff97d77c96661..43362009f5debff9fa5695ad1f84d8ed2c65852c 100644 --- a/test/reference/pack.hpp +++ b/test/reference/pack.hpp @@ -22,6 +22,7 @@ class DataFormat; /// @param[in] height Number of rows of the source matrix. /// @param[in] width Number of columns of the source matrix. std::vector pack( - const DataFormat& dst_format, const void* src, const DataFormat& src_format, size_t height, size_t width); + const DataFormat& dst_format, const void* src, const void* scales, const void* zero_points, + const DataFormat& src_format, size_t height, size_t width); } // namespace kai::test diff --git a/test/reference/quantize.cpp b/test/reference/quantize.cpp index 98b29af5a34ae111cabd6e60da7b5a6a5b861d55..4b233c817d55668451db88b26a24ea1ef9d7b0aa 100644 --- a/test/reference/quantize.cpp +++ b/test/reference/quantize.cpp @@ -4,11 +4,20 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "test/reference/quantize.hpp" + #include #include #include #include +#include +#include "src/kai_common.h" +#include "test/common/data_type.hpp" +#include "test/common/int4.hpp" +#include "test/common/memory.hpp" +#include "test/common/numeric_limits.hpp" +#include "test/common/type_traits.hpp" #include "test/reference/round.hpp" namespace kai::test { @@ -45,4 +54,64 @@ int8_t quantize_i8_fp32(float value, float scale, int32_t zero_point) { std::numeric_limits::max()); } +namespace { + +template +std::vector dequantize_any_type( + const void* data, const void* scales, const void* zero_points, // + QuantizationMethod method, bool is_asymm, size_t height, size_t width) { + static_assert(is_floating_point); + static_assert(is_integral); + + std::vector dst; + dst.resize(height * width * sizeof(Output)); + KAI_ASSUME(size_in_bits() % 8 == 0); + + auto scale = 
read_array(scales, 0); + KAI_UNUSED(is_asymm); + KAI_UNUSED(zero_points); + auto zero_point = is_asymm ? read_array(zero_points, 0) : // + -static_cast(numeric_lowest>); + + for (size_t y = 0; y < height; ++y) { + if (method == QuantizationMethod::PER_ROW) { + scale = read_array(scales, y); + + if (is_asymm) { + zero_point = read_array(zero_points, y); + } + } + + for (size_t x = 0; x < width; ++x) { + const ZeroPoint input = read_array(data, y * width + x); + const Scale output = Scale(input - zero_point) * scale; + write_array(dst.data(), y * width + x, output); + } + } + + return dst; +} + +} // namespace + +std::vector dequantize( + const void* data, const void* scales, const void* zero_points, // + DataType src_dt, DataType dst_dt, QuantizationMethod method, // + size_t height, size_t width) { + switch (src_dt) { + case DataType::QSU4: + switch (dst_dt) { + case DataType::FP32: + return dequantize_any_type( + data, scales, zero_points, method, false, height, width); + + default: + KAI_ERROR("Unsupported destination data type!"); + } + + default: + KAI_ERROR("Unsupported source data type!"); + } +} + } // namespace kai::test diff --git a/test/reference/quantize.hpp b/test/reference/quantize.hpp index b073d460050215a4cf1f97ced5f8c3665e2ff4bd..904ac5fc6f357bf2a9a1355b617dfc82f072d285 100644 --- a/test/reference/quantize.hpp +++ b/test/reference/quantize.hpp @@ -6,11 +6,21 @@ #pragma once +#include #include #include +#include + +#include "test/common/data_type.hpp" namespace kai::test { +/// Quantization method. +enum class QuantizationMethod : uint32_t { + PER_MATRIX, ///< Per-matrix, i.e. one quantization scale and zero point for the entire matrix. + PER_ROW, ///< Per-row, i.e. one quantization scale and zero point for each row. +}; + /// Calculates the quantization information for 8-bit signed asymmetric type from the value range. /// /// @param[in] min_value Minimum value. 
@@ -30,4 +40,21 @@ std::tuple get_qi8_scale_zero_point_from_range(float min_value, /// @return The quantized value. int8_t quantize_i8_fp32(float value, float scale, int32_t zero_point); +/// Dequantizes the matrix to floating-point. +/// +/// @param[in] data Quantized data buffer. +/// @param[in] scales Quantization scales. +/// @param[in] zero_points (Optional) Quantization zero points. +/// @param[in] src_dt Quantized data type. +/// @param[in] dst_dt Dequantized data type. +/// @param[in] method Quantization method. +/// @param[in] height Number of rows. +/// @param[in] width Number of columns. +/// +/// @return The dequantized data buffer. +std::vector dequantize( + const void* data, const void* scales, const void* zero_points, // + DataType src_dt, DataType dst_dt, QuantizationMethod method, // + size_t height, size_t width); + } // namespace kai::test diff --git a/test/reference/reduce.cpp b/test/reference/reduce.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ba14c31fd4c26629492394a2c3f62dcb682f8aa5 --- /dev/null +++ b/test/reference/reduce.cpp @@ -0,0 +1,113 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/reduce.hpp" + +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/common/int4.hpp" +#include "test/common/memory.hpp" + +namespace kai::test { + +namespace { + +template +T scalar_reduce(T curr_value, T new_value) { + if constexpr (op == ReductionOperator::ADD) { + return curr_value + new_value; + } +} + +template +std::vector reduce_any_op_type(const void* src, size_t height, size_t width, size_t dimension) { + std::vector dst; + + switch (dimension) { + case 0: + dst.resize(height * size_in_bits() / 8); + KAI_ASSUME(height * size_in_bits() % 8 == 0); + + for (size_t y = 0; y < height; ++y) { + Output acc = 
read_array(src, y * width); + + for (size_t x = 1; x < width; ++x) { + Output value = read_array(src, y * width + x); + acc = scalar_reduce(acc, value); + } + + write_array(dst.data(), y, acc); + } + + break; + + case 1: + dst.resize(width * size_in_bits() / 8); + KAI_ASSUME(width * size_in_bits() % 8 == 0); + + for (size_t x = 0; x < width; ++x) { + Output acc = read_array(src, x); + + for (size_t y = 1; y < height; ++y) { + Output value = read_array(src, y * width + x); + acc = scalar_reduce(acc, value); + } + + write_array(dst.data(), x, acc); + } + + break; + + default: + KAI_ERROR("Only 2D data is supported!"); + } + + return dst; +} + +template +std::vector reduce_any_op( + const void* src, const DataFormat& src_format, size_t height, size_t width, const DataFormat& dst_format, + size_t dimension) { + KAI_ASSUME(src_format.is_raw()); + KAI_ASSUME(dst_format.is_raw()); + KAI_ASSUME(dimension < 2); + KAI_ASSUME(height > 0); + KAI_ASSUME(width > 0); + + const auto src_dt = src_format.data_type(); + const auto dst_dt = dst_format.data_type(); + + switch (src_dt) { + case DataType::QSU4: + switch (dst_dt) { + case DataType::I32: + return reduce_any_op_type(src, height, width, dimension); + break; + + default: + KAI_ERROR("Unsupported data type!"); + } + + default: + KAI_ERROR("Unsupported data type!"); + } +} + +} // namespace + +std::vector reduce_add( + const void* src, const DataFormat& src_format, size_t height, size_t width, const DataFormat& dst_format, + size_t dimension) { + return reduce_any_op(src, src_format, height, width, dst_format, dimension); +} + +} // namespace kai::test diff --git a/test/reference/reduce.hpp b/test/reference/reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f6ba197a0f5bdcf9ef16a35fe4ee4e74cfe50450 --- /dev/null +++ b/test/reference/reduce.hpp @@ -0,0 +1,36 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma 
once + +#include +#include +#include + +namespace kai::test { + +class DataFormat; + +/// Reduction operator. +enum class ReductionOperator : uint32_t { + ADD, ///< Addition. +}; + +/// Reduces the matrix value using addition. +/// +/// @param[in] src Input data. +/// @param[in] src_format Input data format. +/// @param[in] height Number of rows. +/// @param[in] width Number of columns. +/// @param[in] dst_format Output data format. +/// @param[in] dimension Reduction dimension. +/// +/// @return The reduced matrix. +std::vector reduce_add( + const void* src, const DataFormat& src_format, size_t height, size_t width, const DataFormat& dst_format, + size_t dimension); + +} // namespace kai::test diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index d8be126a7a5ffd2d0f8a7f8590e2119a0b484cd5..392bb621c83e1ea00f1aaba4e663f2fb8e218a9a 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -4,6 +4,8 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "test/reference/matmul.hpp" + #include #include @@ -14,7 +16,10 @@ #include #include +#include "src/kai_common.h" #include "src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.h" +#include "src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.h" +#include "src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h" #include "test/common/compare.hpp" #include "test/common/data_format.hpp" #include "test/common/data_type.hpp" @@ -29,11 +34,16 @@ namespace kai::test { struct MatMulMethod { size_t m0; ///< Block size in M dimension. size_t n0; ///< Block size in N dimension. + size_t k0; ///< Block size in K dimension. + + bool lhs_transposed; ///< LHS matrix is transposed. + bool rhs_transposed; ///< RHS matrix is transposed. DataFormat dst_format; ///< Data format of the destination matrix. DataFormat lhs_format; ///< Data format of the LHS matrix. DataFormat packed_lhs_format; ///< Data format of the packed LHS matrix.
DataFormat rhs_format; ///< Data format of the RHS matrix. + DataFormat packed_rhs_format; ///< Data format of the packed RHS matrix. /// Gets the offset in bytes of the LHS matrix. /// @@ -60,7 +70,53 @@ struct MatMulMethod { /// @param[in] lhs LHS matrix data buffer. /// @param[in] lhs_row_stride Row stride in bytes of the LHS matrix. /// @param[out] packed_lhs Packed LHS matrix data buffer. - std::function fn_preprocess_lhs; + std::function fn_pack_lhs; + + /// Gets the offset in bytes of the RHS matrix. + std::function fn_get_rhs_offset; + + /// Gets the size in bytes of the packed RHS matrix. + std::function fn_get_packed_rhs_size; + + /// Gets the offset in bytes of the packed RHS matrix. + std::function fn_get_packed_rhs_offset; + + /// Preprocesses the RHS matrix. + std::function + fn_pack_rhs_with_bias_scale; + + std::function + fn_main; + + /// Gets a value indicating whether pre-processing the RHS matrix is needed. + [[nodiscard]] bool is_pack_rhs_needed() const { + return fn_pack_rhs_with_bias_scale != nullptr; + } + + /// Preprocesses the RHS matrix. + void pack_rhs( + size_t n, size_t k, const void* rhs, [[maybe_unused]] size_t rhs_row_stride, const void* bias, + const void* scale, void* packed_rhs) const { + if (fn_pack_rhs_with_bias_scale != nullptr) { + struct kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0_params param { + .lhs_zero_point = 1, .rhs_zero_point = 8, + }; + + const auto sr = packed_rhs_format.block_width() / packed_rhs_format.subblock_width(); + const auto kr = packed_rhs_format.block_width() / sr; + + fn_pack_rhs_with_bias_scale( + 1, n, k, n0, kr, sr, reinterpret_cast(rhs), reinterpret_cast(bias), + reinterpret_cast(scale), packed_rhs, 0, &param); + } else { + KAI_ERROR("RHS pre-processing is not supported!"); + } + } }; /// List of supported matrix multiplication methods.
@@ -68,17 +124,30 @@ static const std::array matmul_methods = { MatMulMethod{ .m0 = 4, .n0 = 4, + .k0 = 32, + + .lhs_transposed = false, + .rhs_transposed = true, .dst_format = DataFormat(DataType::FP32), .lhs_format = DataFormat(DataType::FP32), .packed_lhs_format = DataFormat(DataType::QI8, 4, 8, DataFormat::QuantizationFormat::PER_ROW, DataType::FP32, DataType::I32), - .rhs_format = DataFormat(DataType::UNKNOWN), // Unused. + .rhs_format = DataFormat(DataType::QSU4), + .packed_rhs_format = DataFormat( + DataType::QSI4, 4, 32, DataFormat::QuantizationFormat::PER_ROW, DataType::FP32, DataType::I32, 1, 16), .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qa8dxP4X8_f32, .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP4X8_f32, .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qa8dxP4X8_f32, - .fn_preprocess_lhs = kai_run_lhs_quant_pack_qa8dxP4X8_f32, + .fn_pack_lhs = kai_run_lhs_quant_pack_qa8dxP4X8_f32, + + .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_qs4cxP_qs4cxS1S0, + .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qs4cxP_qs4cxS1S0, + .fn_get_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_qs4cxP_qs4cxS1S0, + .fn_pack_rhs_with_bias_scale = kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0, + + .fn_main = kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm, }, }; @@ -92,6 +161,15 @@ struct MatMulShape { /// Matrix multiplication test information. using MatMulTestParams = std::tuple; +/// Prints the test information. +void PrintTo(const MatMulTestParams& param, std::ostream* os) { + const auto& [shape, method_no, portion] = param; + + *os << "m: " << shape.m << ", n: " << shape.n << ", k: " << shape.k << ", method_no: " << method_no + << ", portion: { start_row: " << portion.start_row() << ", start_col: " << portion.start_col() + << ", height: " << portion.height() << ", width: " << portion.width() << "}"; +} + /// Matrix multiplication test fixture. 
class MatMulTest : public testing::TestWithParam { private: @@ -103,10 +181,14 @@ protected: struct TestData { std::vector lhs; ///< LHS operand. std::vector ref_packed_lhs; ///< Reference packed LHS. + std::vector rhs; ///< RHS operand. + std::vector rhs_scales; ///< RHS per-row quantization scales. + std::vector ref_packed_rhs; ///< Reference packed RHS. + std::vector ref_dst; ///< Reference output. }; /// Gets the test data for the current test case. - TestData& test_data() { + const TestData& test_data() { const auto& [info, method_no, portion] = GetParam(); const TestDataId data_id{info.m, info.n, info.k, method_no}; @@ -120,10 +202,45 @@ protected: // Generates the test data. const auto& method = matmul_methods.at(method_no); - auto lhs = fill_matrix_random(info.m, info.k, method.lhs_format, 0); - auto ref_packed_lhs = pack(method.packed_lhs_format, lhs.data(), method.lhs_format, info.m, info.k); + const auto lhs_h = method.lhs_transposed ? info.k : info.m; + const auto lhs_w = method.lhs_transposed ? info.m : info.k; + auto lhs = fill_matrix_random(lhs_h, lhs_w, method.lhs_format, 0); + auto ref_packed_lhs = + pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w); + + const auto rhs_h = method.rhs_transposed ? info.n : info.k; + const auto rhs_w = method.rhs_transposed ? info.k : info.n; + auto rhs = fill_matrix_random(rhs_h, rhs_w, method.rhs_format, 1); + + std::vector rhs_scales; + if (data_type_is_quantized(method.rhs_format.data_type()) && + method.rhs_format.quantization_format() == DataFormat::QuantizationFormat::NONE) { + rhs_scales = fill_matrix_random(rhs_h, 1, DataFormat(DataType::FP32), 2); + } - return _data[data_id] = {std::move(lhs), std::move(ref_packed_lhs)}; + auto packed_rhs = matmul_pack_rhs( + rhs.data(), !rhs_scales.empty() ? 
rhs_scales.data() : nullptr, nullptr, method.rhs_format, + method.packed_rhs_format, rhs_h, rhs_w); + + KAI_ASSUME(method.lhs_format.is_raw()); + KAI_ASSUME(method.rhs_format.is_raw()); + KAI_ASSUME(method.dst_format.is_raw()); + auto ref_dst = matmul( + lhs.data(), nullptr, nullptr, method.lhs_format.data_type(), // + rhs.data(), rhs_scales.data(), nullptr, method.rhs_format.data_type(), // + method.dst_format.data_type(), // + info.m, info.n, info.k, method.lhs_transposed, method.rhs_transposed); + + const auto& data = _data[data_id] = { + .lhs = std::move(lhs), + .ref_packed_lhs = std::move(ref_packed_lhs), + .rhs = std::move(rhs), + .rhs_scales = std::move(rhs_scales), + .ref_packed_rhs = std::move(packed_rhs), + .ref_dst = std::move(ref_dst), + }; + + return data; } private: @@ -136,50 +253,145 @@ TEST_P(MatMulTest, PackedLhs) { const auto& data = test_data(); const auto& method = matmul_methods.at(method_no); - if (method.fn_preprocess_lhs == nullptr) { + if (method.fn_pack_lhs == nullptr) { GTEST_SKIP(); } - const auto [rect_start_row, rect_start_col, rect_height, rect_width] = portion.compute_portion( - info.m, info.k, method.packed_lhs_format.scheduler_block_height(info.m), - method.packed_lhs_format.scheduler_block_width(info.k)); + const auto lhs_h = method.lhs_transposed ? info.k : info.m; + const auto lhs_w = method.lhs_transposed ? 
info.m : info.k; + + const auto rect = portion.compute_portion( + lhs_h, lhs_w, method.packed_lhs_format.scheduler_block_height(lhs_h), + method.packed_lhs_format.scheduler_block_width(lhs_w)); - if (rect_height == 0 || rect_width == 0) { + if (rect.height() == 0 || rect.width() == 0) { GTEST_SKIP(); } - const auto ref_lhs_row_stride = method.lhs_format.default_row_stride(info.k); + const auto ref_lhs_row_stride = method.lhs_format.default_row_stride(lhs_w); const auto packed_lhs_size = method.fn_get_packed_lhs_size(info.m, info.k); - const auto ref_packed_lhs_size = method.packed_lhs_format.default_size_in_bytes(info.m, info.k); + const auto ref_packed_lhs_size = method.packed_lhs_format.default_size_in_bytes(lhs_h, lhs_w); ASSERT_EQ(packed_lhs_size, ref_packed_lhs_size); - const auto lhs_offset = method.fn_get_lhs_offset(rect_start_row, ref_lhs_row_stride); - const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(rect_start_row, 0, info.k); + const auto lhs_offset = method.fn_get_lhs_offset(rect.start_row(), ref_lhs_row_stride); + const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(rect.start_row(), 0, lhs_w); ASSERT_EQ(lhs_offset, ref_lhs_offset); - const auto packed_lhs_offset = method.fn_get_packed_lhs_offset(rect_start_row, info.k); - const auto ref_packed_lhs_offset = method.packed_lhs_format.default_offset_in_bytes(rect_start_row, 0, info.k); + const auto packed_lhs_offset = method.fn_get_packed_lhs_offset(rect.start_row(), info.k); + const auto ref_packed_lhs_offset = method.packed_lhs_format.default_offset_in_bytes(rect.start_row(), 0, lhs_w); ASSERT_EQ(packed_lhs_offset, ref_packed_lhs_offset); std::vector packed_lhs; packed_lhs.resize(packed_lhs_size); - method.fn_preprocess_lhs( - rect_height, rect_width, data.lhs.data() + lhs_offset, ref_lhs_row_stride, + method.fn_pack_lhs( + rect.height(), rect.width(), data.lhs.data() + lhs_offset, ref_lhs_row_stride, packed_lhs.data() + packed_lhs_offset); const auto success = 
compare( - packed_lhs.data(), data.ref_packed_lhs.data(), method.packed_lhs_format, info.m, info.k, portion, + packed_lhs.data(), data.ref_packed_lhs.data(), method.packed_lhs_format, lhs_h, lhs_w, rect, + DefaultMismatchHandler(0, 0.0001, 0, 0.001)); + ASSERT_TRUE(success); +} + +/// Tests the RHS packing kernel. +TEST_P(MatMulTest, PackedRhs) { + const auto& [info, method_no, portion] = GetParam(); + const auto& data = test_data(); + const auto& method = matmul_methods.at(method_no); + + if (!method.is_pack_rhs_needed()) { + GTEST_SKIP(); + } + + const auto rhs_h = method.rhs_transposed ? info.n : info.k; + const auto rhs_w = method.rhs_transposed ? info.k : info.n; + + const auto rect = portion.compute_portion( + rhs_h, rhs_w, method.packed_rhs_format.scheduler_block_height(rhs_h), + method.packed_rhs_format.scheduler_block_width(rhs_w)); + + if (rect.height() == 0 || rect.width() == 0) { + GTEST_SKIP(); + } + + const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(rhs_w); + + const auto rhs_offset = method.fn_get_rhs_offset(rect.start_row(), ref_rhs_row_stride); + const auto ref_rhs_offset = method.rhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), rhs_w); + ASSERT_EQ(rhs_offset, ref_rhs_offset); + + const auto packed_rhs_size = method.fn_get_packed_rhs_size( + rhs_h, rhs_w, method.packed_rhs_format.block_height(), method.packed_rhs_format.block_width()); + const auto ref_packed_rhs_size = method.packed_rhs_format.default_size_in_bytes(rhs_h, rhs_w); + ASSERT_EQ(packed_rhs_size, ref_packed_rhs_size); + + const auto packed_rhs_offset = method.fn_get_packed_rhs_offset( + rect.start_row(), rhs_w, method.packed_rhs_format.block_height(), method.packed_rhs_format.block_width()); + const auto ref_packed_rhs_offset = + method.packed_rhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), rhs_w); + ASSERT_EQ(packed_rhs_offset, ref_packed_rhs_offset); + + const auto ref_rhs_scales_offset = + rect.start_row() * 
data_type_size_in_bits(method.packed_rhs_format.scale_data_type()) / 8; + + std::vector packed_rhs; + packed_rhs.resize(packed_rhs_size); + + method.pack_rhs( + rect.height(), rect.width(), data.rhs.data() + rhs_offset, ref_rhs_row_stride, nullptr, + !data.rhs_scales.empty() ? data.rhs_scales.data() + ref_rhs_scales_offset : nullptr, + packed_rhs.data() + packed_rhs_offset); + + const auto success = compare( + packed_rhs.data(), data.ref_packed_rhs.data(), method.packed_rhs_format, rhs_h, rhs_w, rect, DefaultMismatchHandler(0, 0.0001, 0, 0.001)); ASSERT_TRUE(success); } +/// Tests the output. +TEST_P(MatMulTest, Output) { + const auto& [info, method_no, portion] = GetParam(); + const auto& data = test_data(); + const auto& method = matmul_methods.at(method_no); + + const auto rect = portion.compute_portion(info.m, info.n, method.m0, method.n0); + + if (rect.height() == 0 || rect.width() == 0) { + GTEST_SKIP(); + } + + const auto ref_dst_row_stride = method.dst_format.default_row_stride(info.n); + const auto ref_dst_col_stride = data_type_size_in_bits(method.dst_format.data_type()) / 8; + + const auto ref_packed_lhs_offset = method.packed_lhs_format.default_offset_in_bytes( + method.lhs_transposed ? 0 : rect.start_row(), method.lhs_transposed ? rect.start_row() : 0, + method.lhs_transposed ? info.m : info.k); + const auto ref_packed_rhs_offset = method.packed_rhs_format.default_offset_in_bytes( + method.rhs_transposed ? rect.start_col() : 0, method.rhs_transposed ? 0 : rect.start_col(), + method.rhs_transposed ? 
info.k : info.n); + const auto ref_dst_offset = method.dst_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), info.n); + + std::vector dst; + dst.resize(method.dst_format.default_size_in_bytes(info.m, info.n)); + + method.fn_main( + rect.height(), rect.width(), info.k, data.ref_packed_lhs.data() + ref_packed_lhs_offset, + data.ref_packed_rhs.data() + ref_packed_rhs_offset, reinterpret_cast(dst.data() + ref_dst_offset), + ref_dst_row_stride, ref_dst_col_stride, -FLT_MAX, FLT_MAX); + + const auto success = compare( + dst.data(), data.ref_dst.data(), method.dst_format, info.m, info.n, rect, + DefaultMismatchHandler(0, 0.1, 0, 0.05)); + ASSERT_TRUE(success); +} + INSTANTIATE_TEST_SUITE_P( MatMul, MatMulTest, testing::Combine( testing::Values( MatMulShape{4, 4, 32}, // - MatMulShape{12, 16, 48}), + MatMulShape{12, 16, 64}), testing::Range(0, matmul_methods.size()), testing::Values( MatrixPortion(0, 0, 1, 1), // Full matrix.