From 133a04448a87e74ca22d0101a3cd947fef7d918c Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Tue, 30 Apr 2024 12:42:29 +0100 Subject: [PATCH 01/11] Create test framework * Add common concepts: - Data type. - Data format. * Add custom data types: - 4-bit signed integer. - 4-bit unsigned integer. * Add generic type utilities: - Type traits. - Numeric limits. - Memory access. * Add matrix printer for arbitrary data format. * Add reference implementation of rounding functions. Signed-off-by: Viet-Hoa Do --- .clang-tidy | 5 ++ CMakeLists.txt | 34 ++++++-- src/kai_common.h | 42 ++++++++++ test/common/data_format.cpp | 144 +++++++++++++++++++++++++++++++++ test/common/data_format.hpp | 126 +++++++++++++++++++++++++++++ test/common/data_type.cpp | 81 +++++++++++++++++++ test/common/data_type.hpp | 110 +++++++++++++++++++++++++ test/common/int4.cpp | 98 ++++++++++++++++++++++ test/common/int4.hpp | 112 +++++++++++++++++++++++++ test/common/logging.hpp | 68 ++++++++++++++++ test/common/memory.hpp | 77 ++++++++++++++++++ test/common/numeric_limits.hpp | 39 +++++++++ test/common/printer.cpp | 112 +++++++++++++++++++++++++ test/common/printer.hpp | 28 +++++++ test/common/type_traits.hpp | 79 ++++++++++++++++++ test/reference/round.cpp | 36 +++++++++ test/reference/round.hpp | 62 ++++++++++++++ 17 files changed, 1246 insertions(+), 7 deletions(-) create mode 100644 src/kai_common.h create mode 100644 test/common/data_format.cpp create mode 100644 test/common/data_format.hpp create mode 100644 test/common/data_type.cpp create mode 100644 test/common/data_type.hpp create mode 100644 test/common/int4.cpp create mode 100644 test/common/int4.hpp create mode 100644 test/common/logging.hpp create mode 100644 test/common/memory.hpp create mode 100644 test/common/numeric_limits.hpp create mode 100644 test/common/printer.cpp create mode 100644 test/common/printer.hpp create mode 100644 test/common/type_traits.hpp create mode 100644 test/reference/round.cpp create mode 100644 test/reference/round.hpp diff --git a/.clang-tidy b/.clang-tidy index 7cff2942..bebfc13a 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -22,5 +22,10 @@ readability-*, -readability-identifier-length, -readability-magic-numbers, -readability-function-cognitive-complexity, +-cppcoreguidelines-pro-type-reinterpret-cast, +-cppcoreguidelines-avoid-magic-numbers, +-readability-simplify-boolean-expr, +-bugprone-easily-swappable-parameters, +-cppcoreguidelines-pro-bounds-pointer-arithmetic ' ... diff --git a/CMakeLists.txt b/CMakeLists.txt index d45d5070..b6d9662c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,9 +31,7 @@ endif() set(KLEIDIAI_WARNING_FLAGS "-Wall" - "-Wctor-dtor-privacy" "-Wdisabled-optimization" - "-Weffc++" "-Werror" "-Wextra" "-Wformat-security" @@ -42,19 +40,41 @@ set(KLEIDIAI_WARNING_FLAGS "-Wno-ignored-attributes" "-Wno-misleading-indentation" "-Wno-overlength-strings" - "-Woverloaded-virtual" - "-Wsign-promo" "-Wstrict-overflow=2" "-Wswitch-default" ) +set(KLEIDIAI_WARNING_FLAGS_CXX + "-Wctor-dtor-privacy" + "-Weffc++" + "-Woverloaded-virtual" + "-Wsign-promo" +) + if(KLEIDIAI_BUILD_TESTS) enable_testing() include(GoogleTest) - add_executable(kleidiai_test test/sample.cpp) - target_compile_options(kleidiai_test PRIVATE ${KLEIDIAI_WARNING_FLAGS}) - target_link_libraries(kleidiai_test PRIVATE GTest::gtest_main) + add_executable(kleidiai_test + test/common/data_type.cpp + test/common/data_format.cpp + test/common/printer.cpp + test/common/int4.cpp + test/reference/round.cpp + ) + + target_include_directories(kleidiai_test + PRIVATE . + ) + + target_compile_options(kleidiai_test + PRIVATE ${KLEIDIAI_WARNING_FLAGS} + PRIVATE ${KLEIDIAI_WARNING_FLAGS_CXX} + ) + + target_link_libraries(kleidiai_test + PRIVATE GTest::gtest_main + ) # Cross-compiling is a common use case which creates a conflict if DISCOVERY_MODE is set to POST_BUILD (by default) # since the host platform does not match the target. Setting the mode to PRE_TEST avoids this conflict. diff --git a/src/kai_common.h b/src/kai_common.h new file mode 100644 index 00000000..b44dad5b --- /dev/null +++ b/src/kai_common.h @@ -0,0 +1,42 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +// NOLINTBEGIN(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c) +// +// * cppcoreguidelines-avoid-do-while: do-while is necessary for macros. +// * cppcoreguidelines-pro-type-vararg: use of variadic arguments in fprintf is expected. +// * cert-err33-c: checking the output of fflush and fprintf is not necessary for error reporting. +#define KAI_ERROR(msg) \ + do { \ + fflush(stdout); \ + fprintf(stderr, "%s", msg); \ + exit(EXIT_FAILURE); \ + } while (0) + +#define KAI_ASSERT_MSG(cond, msg) \ + do { \ + if (!(cond)) { \ + KAI_ERROR(msg); \ + } \ + } while (0) +// NOLINTEND(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c) + +#define KAI_ASSERT(cond) KAI_ASSERT_MSG(cond, #cond) + +#define KAI_ASSERT_IF_MSG(precond, cond, msg) KAI_ASSERT_MSG(!(precond) || (cond), msg) +#define KAI_ASSERT_IF(precond, cond) KAI_ASSERT_IF_MSG(precond, cond, #precond " |-> " #cond) + +#define KAI_ASSUME_MSG KAI_ASSERT_MSG +#define KAI_ASSUME KAI_ASSERT +#define KAI_ASSUME_IF_MSG KAI_ASSERT_IF_MSG +#define KAI_ASSUME_IF KAI_ASSERT_IF + +#define KAI_UNUSED(x) (void)(x) diff --git a/test/common/data_format.cpp b/test/common/data_format.cpp new file mode 100644 index 00000000..1673b405 --- /dev/null +++ b/test/common/data_format.cpp @@ -0,0 +1,144 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/data_format.hpp" + +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_type.hpp" +#include "test/reference/round.hpp" + +namespace kai::test { + +DataFormat::DataFormat( + DataType data_type, size_t block_height, size_t block_width, QuantizationFormat quant_format, DataType scale_dt, + DataType zero_point_dt, size_t subblock_height, size_t subblock_width) noexcept : + _data_type(data_type), + _quant_format(quant_format), + _scale_dt(scale_dt), + _zero_point_dt(zero_point_dt), + _block_height(block_height), + _block_width(block_width), + _subblock_height(subblock_height), + _subblock_width(subblock_width) { +} + +bool DataFormat::operator==(const DataFormat& rhs) const { + return _data_type == rhs._data_type && _quant_format == rhs._quant_format && _scale_dt == rhs._scale_dt && + _zero_point_dt == rhs._zero_point_dt && _block_height == rhs._block_height && _block_width == rhs._block_width; +} + +bool DataFormat::operator!=(const DataFormat& rhs) const { + return !(*this == rhs); +} + +DataType DataFormat::data_type() const { + return _data_type; +} + +DataFormat::QuantizationFormat DataFormat::quantization_format() const { + return _quant_format; +} + +DataType DataFormat::scale_data_type() const { + return _scale_dt; +} + +DataType DataFormat::zero_point_data_type() const { + return _zero_point_dt; +} + +bool DataFormat::is_raw() const { + return _quant_format == QuantizationFormat::NONE && // + _block_height == 0 && _block_width == 0 && _subblock_height == 0 && _subblock_width == 0; +} + +size_t DataFormat::block_height() const { + return _block_height; +} + +size_t DataFormat::block_width() const { + return _block_width; +} + +size_t DataFormat::subblock_height() const { + return _subblock_height; +} + +size_t DataFormat::subblock_width() const { + return _subblock_width; +} + +size_t DataFormat::scheduler_block_height([[maybe_unused]] size_t full_height) const { + switch (_quant_format) { + case QuantizationFormat::NONE: + return _block_height > 0 ? _block_height : 1; + + case QuantizationFormat::PER_ROW: + return _block_height; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +size_t DataFormat::scheduler_block_width(size_t full_width) const { + switch (_quant_format) { + case QuantizationFormat::NONE: + return _block_width > 0 ? _block_width : 1; + + case QuantizationFormat::PER_ROW: + return full_width; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +uintptr_t DataFormat::default_row_stride(size_t width) const { + const auto padded_width = round_up_multiple(width, _block_width > 0 ? _block_width : 1); + + switch (_quant_format) { + case QuantizationFormat::NONE: + return padded_width * data_type_size_in_bits(_data_type) / 8; + + case QuantizationFormat::PER_ROW: + return _block_height * data_type_size_in_bits(_zero_point_dt) / 8 + // + _block_height * padded_width * data_type_size_in_bits(_data_type) / 8 + // + _block_height * data_type_size_in_bits(_scale_dt) / 8; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +uintptr_t DataFormat::default_offset_in_bytes(size_t row, size_t col, size_t width) const { + const auto row_stride = default_row_stride(width); + + KAI_ASSERT(col % scheduler_block_width(width) == 0); + + switch (_quant_format) { + case QuantizationFormat::NONE: + return row * row_stride + col * data_type_size_in_bits(_data_type) / 8; + + case QuantizationFormat::PER_ROW: + KAI_ASSERT(row % _block_height == 0); + KAI_ASSERT(col == 0); + return (row / _block_height) * row_stride + col * data_type_size_in_bits(_data_type) / 8; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +size_t DataFormat::default_size_in_bytes(size_t height, size_t width) const { + const auto num_rows = _block_height > 0 ? (height + _block_height - 1) / _block_height : height; + return num_rows * default_row_stride(width); +} + +} // namespace kai::test diff --git a/test/common/data_format.hpp b/test/common/data_format.hpp new file mode 100644 index 00000000..3aa19f71 --- /dev/null +++ b/test/common/data_format.hpp @@ -0,0 +1,126 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "test/common/data_type.hpp" + +namespace kai::test { + +/// Data format. +class DataFormat { +public: + /// Quantization packing format. + enum class QuantizationFormat : uint32_t { + NONE, ///< No quantization information is included. + PER_ROW, ///< Per-row quantization. + }; + + /// Creates a new data format. + /// + /// @param[in] data_type Data type of data value. + /// @param[in] block_height Block height. + /// @param[in] block_width Block width. + /// @param[in] quant_format Quantization packing format. + /// @param[in] scale_dt Data type of scale value. + /// @param[in] zero_point_dt Data type of zero point value. + /// @param[in] subblock_height Sub-block height. + /// @param[in] subblock_width Sub-block width. + DataFormat( + DataType data_type, size_t block_height = 0, size_t block_width = 0, + QuantizationFormat quant_format = QuantizationFormat::NONE, DataType scale_dt = DataType::UNKNOWN, + DataType zero_point_dt = DataType::UNKNOWN, size_t subblock_height = 0, size_t subblock_width = 0) noexcept; + + /// Equality operator. + [[nodiscard]] bool operator==(const DataFormat& rhs) const; + + /// Unequality operator. + [[nodiscard]] bool operator!=(const DataFormat& rhs) const; + + /// Gets the quantization packing format. + [[nodiscard]] QuantizationFormat quantization_format() const; + + /// Gets the data type of data value. + [[nodiscard]] DataType data_type() const; + + /// Gets the data type of scale value. + [[nodiscard]] DataType scale_data_type() const; + + /// Gets the data type of zero point value. + [[nodiscard]] DataType zero_point_data_type() const; + + /// Gets a value indicating whether this format has no blocking or packed quantization information. + [[nodiscard]] bool is_raw() const; + + /// Gets the block height. + [[nodiscard]] size_t block_height() const; + + /// Gets the block width. + [[nodiscard]] size_t block_width() const; + + /// Gets the sub-block height. + [[nodiscard]] size_t subblock_height() const; + + /// Gets the sub-block width. + [[nodiscard]] size_t subblock_width() const; + + /// Gets the scheduling block height. + /// + /// @param[in] full_height Height of the full matrix. + /// + /// @return The block height for scheduling purpose. + [[nodiscard]] size_t scheduler_block_height(size_t full_height) const; + + /// Gets the scheduling block width. + /// + /// @param[in] full_width Width of the full matrix. + /// + /// @return The block width for scheduling purpose. + [[nodiscard]] size_t scheduler_block_width(size_t full_width) const; + + /// Gets the row stride in bytes given the data is stored continuously without any gap in the memory. + /// + /// In case of per-row quantization, the row stride is the number of bytes from one row group + /// to the next. One row group consists of `block_height` rows. + /// + /// @param[in] width Width of the full matrix. + /// + /// @return The default row stride in bytes of the matrix. + [[nodiscard]] uintptr_t default_row_stride(size_t width) const; + + /// Gets the offsets in bytes in the data buffer given the data is stored continuously + /// without any gap in the memory. + /// + /// @param[in] row Row coordinate. + /// @param[in] col Colum coordinate. + /// @param[in] width Width of the full matrix. + /// + /// @return The default offset in bytes. + [[nodiscard]] uintptr_t default_offset_in_bytes(size_t row, size_t col, size_t width) const; + + /// Gets the size in bytes of the matrix given the data is stored continuously without any gap in the memory. + /// + /// @param[in] height Height of the full matrix. + /// @param[in] width Width of the full matrix. + /// + /// @return The size in bytes of the matrix. + [[nodiscard]] size_t default_size_in_bytes(size_t height, size_t width) const; + +private: + DataType _data_type; + QuantizationFormat _quant_format; + DataType _scale_dt; + DataType _zero_point_dt; + size_t _block_height; + size_t _block_width; + size_t _subblock_height; + size_t _subblock_width; +}; + +} // namespace kai::test diff --git a/test/common/data_type.cpp b/test/common/data_type.cpp new file mode 100644 index 00000000..79cc7e7c --- /dev/null +++ b/test/common/data_type.cpp @@ -0,0 +1,81 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/data_type.hpp" + +#include +#include + +#include "src/kai_common.h" + +namespace kai::test { + +namespace { + +bool has_i(DataType dt) { + return (static_cast(dt) & (1 << 15)) != 0; +} + +bool has_s(DataType dt) { + return (static_cast(dt) & (1 << 14)) != 0; +} + +bool has_q(DataType dt) { + return (static_cast(dt) & (1 << 13)) != 0; +} + +bool has_a(DataType dt) { + return (static_cast(dt) & (1 << 12)) != 0; +} + +size_t bits(DataType dt) { + return static_cast(dt) & 0xFF; +} + +} // namespace + +size_t data_type_size_in_bits(DataType dt) { + return bits(dt); +} + +bool data_type_is_integral(DataType dt) { + return has_i(dt); +} + +bool data_type_is_float(DataType dt) { + KAI_ASSERT(data_type_is_signed(dt)); + return !data_type_is_integral(dt); +} + +bool data_type_is_float_fp(DataType dt) { + KAI_ASSERT(data_type_is_float(dt)); + return !has_q(dt); +} + +bool data_type_is_float_bf(DataType dt) { + KAI_ASSERT(data_type_is_float(dt)); + return has_q(dt); +} + +bool data_type_is_signed(DataType dt) { + if (!has_s(dt)) { + KAI_ASSERT(data_type_is_integral(dt)); + } + + return has_s(dt); +} + +bool data_type_is_quantized(DataType dt) { + KAI_ASSERT_IF(has_q(dt), data_type_is_integral(dt)); + return has_q(dt); +} + +bool data_type_is_quantized_asymm(DataType dt) { + KAI_ASSERT_IF(has_a(dt), data_type_is_quantized(dt)); + return has_a(dt); +} + +} // namespace kai::test diff --git a/test/common/data_type.hpp b/test/common/data_type.hpp new file mode 100644 index 00000000..8fbdcf9d --- /dev/null +++ b/test/common/data_type.hpp @@ -0,0 +1,110 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace kai::test { + +/// Data type. +enum class DataType : uint16_t { + // Encoding: + // + // 15 0 + // +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ + // | i | s | q | a | RES0 | bits | + // +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ + // + // (RES0: reserved, filled with 0s) + // + // Fields: + // + // * i: integer (1) or floating-point (0). + // * s: signed (1) or unsigned (0). + // * q: + // - Integer (i): quantized (1) or non-quantized (0). + // - Floating-point (!i): brain (1) or binary (0). + // * a: + // - Quantized (i && q): asymmetric (1) or symmetric (0). + // - Otherwise: RES0. + // * bits: size in bits. + + UNKNOWN = 0, ///< No data. + + FP32 = 0b0'1'0'0'0000'00100000, ///< Single-precision floating-point. + FP16 = 0b0'1'0'0'0000'00010000, ///< Half-precision floating-point. + + I32 = 0b1'1'0'0'0000'00100000, ///< 32-bit signed integer. + + QI8 = 0b1'1'1'1'0000'00001000, ///< 8-bit asymmetric quantized. + + QSU4 = 0b1'0'1'0'0000'00000100, ///< 4-bit unsigned symmetric quantized. + QSI4 = 0b1'1'1'0'0000'00000100, ///< 4-bit signed symmetric quantized. +}; + +/// Gets the size in bits of the specified data type. +/// +/// @param[in] dt The data type. +/// +/// @return The size in bits. +[[nodiscard]] size_t data_type_size_in_bits(DataType dt); + +/// Gets a value indicating whether the data type is integral. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is integral. +[[nodiscard]] bool data_type_is_integral(DataType dt); + +/// Gets a value indicating whether the data type is floating-point. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is floating-point. +[[nodiscard]] bool data_type_is_float(DataType dt); + +/// Gets a value indicating whether the data type is binary floating-point. +/// +/// Binary floating point are `half`, `float`, `double`. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is binary floating-point. +[[nodiscard]] bool data_type_is_float_fp(DataType dt); + +/// Gets a value indicating whether the data type is brain floating-point. +/// +/// Binary floating point are `bfloat16`. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is brain floating-point. +[[nodiscard]] bool data_type_is_float_bf(DataType dt); + +/// Gets a value indicating whether the data type is signed. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is signed. +[[nodiscard]] bool data_type_is_signed(DataType dt); + +/// Gets a value indicating whether the data type is quantized. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is quantized. +[[nodiscard]] bool data_type_is_quantized(DataType dt); + +/// Gets a value indicating whether the data type is asymmetric quantized. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is asymmetric quantized. +[[nodiscard]] bool data_type_is_quantized_asymm(DataType dt); + +} // namespace kai::test diff --git a/test/common/int4.cpp b/test/common/int4.cpp new file mode 100644 index 00000000..bc9563b1 --- /dev/null +++ b/test/common/int4.cpp @@ -0,0 +1,98 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/int4.hpp" + +#include +#include + +namespace kai::test { + +UInt4& UInt4::operator=(uint8_t value) { + _value = value; + return *this; +} + +UInt4::operator int32_t() const { + return _value; +} + +UInt4::operator float() const { + return _value; +} + +UInt4 UInt4::operator+(UInt4 rhs) const { + return UInt4(_value + rhs._value); +} + +UInt4 UInt4::operator-(UInt4 rhs) const { + return UInt4(_value - rhs._value); +} + +UInt4 UInt4::operator*(UInt4 rhs) const { + return UInt4(_value * rhs._value); +} + +UInt4 UInt4::operator/(UInt4 rhs) const { + return UInt4(_value / rhs._value); +} + +uint8_t UInt4::pack_u8(UInt4 low, UInt4 high) { + return (low._value & 0x0F) | (high._value << 4); +} + +std::tuple UInt4::unpack_u8(uint8_t value) { + const uint8_t low = value & 0x0F; + const uint8_t high = value >> 4; + + return {UInt4(low), UInt4(high)}; +} + +// ===================================================================================================================== + +Int4& Int4::operator=(int8_t value) { + _value = value; + return *this; +} + +Int4::operator int32_t() const { + return _value; +} + +Int4::operator float() const { + return _value; +} + +Int4 Int4::operator+(Int4 rhs) const { + return Int4(static_cast(_value + rhs._value)); +} + +Int4 Int4::operator-(Int4 rhs) const { + return Int4(static_cast(_value - rhs._value)); +} + +Int4 Int4::operator*(Int4 rhs) const { + return Int4(static_cast(_value * rhs._value)); +} + +Int4 Int4::operator/(Int4 rhs) const { + return Int4(static_cast(_value / rhs._value)); +} + +uint8_t Int4::pack_u8(Int4 low, Int4 high) { + return (low._value & 0x0F) | (high._value << 4); +} + +std::tuple Int4::unpack_u8(uint8_t value) { + // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + const int8_t low = static_cast(value << 4) >> 4; + const int8_t high = static_cast(value) >> 4; + // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + + return {Int4(low), Int4(high)}; +} + +} // namespace kai::test diff --git a/test/common/int4.hpp b/test/common/int4.hpp new file mode 100644 index 00000000..169c5817 --- /dev/null +++ b/test/common/int4.hpp @@ -0,0 +1,112 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace kai::test { + +/// 4-bit unsigned integer. +class UInt4 { +public: + /// Creates a new 4-bit unsigned integer value. + /// + /// @param[in] value Value. + constexpr explicit UInt4(uint8_t value) : _value(value) { + } + + /// Assignment operator. + UInt4& operator=(uint8_t value); + + /// Conversion operator. + operator int32_t() const; + + /// Conversion operator. + operator float() const; + + /// Addition operator. + [[nodiscard]] UInt4 operator+(UInt4 rhs) const; + + /// Subtraction operator. + [[nodiscard]] UInt4 operator-(UInt4 rhs) const; + + /// Multiplication operator. + [[nodiscard]] UInt4 operator*(UInt4 rhs) const; + + /// Division operator. + [[nodiscard]] UInt4 operator/(UInt4 rhs) const; + + /// Packs two 4-bit unsigned integer values into one byte. + /// + /// @param[in] low Low nibble. + /// @param[in] high High nibble. + /// + /// @return The packed byte. + [[nodiscard]] static uint8_t pack_u8(UInt4 low, UInt4 high); + + /// Unpacks one byte to two 4-bit unsigned integer values. + /// + /// @param[in] value 8-bit packed value. + /// + /// @return The low and high nibbles. + [[nodiscard]] static std::tuple unpack_u8(uint8_t value); + +private: + uint8_t _value; +}; + +/// 4-bit signed integer. +class Int4 { +public: + /// Creates a new 4-bit signed integer value. + /// + /// @param[in] value Value. + constexpr explicit Int4(int8_t value) : _value(value) { + } + + /// Assignment operator. + Int4& operator=(int8_t value); + + /// Conversion operator. + explicit operator int32_t() const; + + /// Conversion operator. + explicit operator float() const; + + /// Addition operator. + [[nodiscard]] Int4 operator+(Int4 rhs) const; + + /// Subtraction operator. + [[nodiscard]] Int4 operator-(Int4 rhs) const; + + /// Multiplication operator. + [[nodiscard]] Int4 operator*(Int4 rhs) const; + + /// Division operator. + [[nodiscard]] Int4 operator/(Int4 rhs) const; + + /// Packs two 4-bit signed integer values into one byte. + /// + /// @param[in] low Low nibble. + /// @param[in] high High nibble. + /// + /// @return The packed byte. + [[nodiscard]] static uint8_t pack_u8(Int4 low, Int4 high); + + /// Unpacks one byte to two 4-bit signed integer values. + /// + /// @param[in] value 8-bit packed value. + /// + /// @return The low and high nibbles. + [[nodiscard]] static std::tuple unpack_u8(uint8_t value); + +private: + int8_t _value; +}; + +} // namespace kai::test diff --git a/test/common/logging.hpp b/test/common/logging.hpp new file mode 100644 index 00000000..ff6130a0 --- /dev/null +++ b/test/common/logging.hpp @@ -0,0 +1,68 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "test/common/int4.hpp" + +#define KAI_LOGE(...) kai::test::detail::log("ERROR", __VA_ARGS__) + +namespace kai::test::detail { + +/// Prints the specified value to standard error. +/// +/// @tparam T Data type. +/// +/// @param[in] value Value to be printed out. +template +void write_log_content(T&& value) { + using TT = std::decay_t; + + if constexpr (std::is_same_v) { + std::cerr << static_cast(value); + } else if constexpr (std::is_same_v) { + std::cerr << static_cast(value); + } else if constexpr (std::is_same_v) { + std::cerr << static_cast(value); + } else if constexpr (std::is_same_v) { + std::cerr << static_cast(value); + } else { + std::cerr << value; + } +} + +/// Prints the specified values to standard error. +/// +/// @tparam T Data type of the first value. +/// @tparam Ts Data types of the subsequent values. +/// +/// @param[in] value First value to be printed out. +/// @param[in] others Subsequent values to be printed out. +template +void write_log_content(T&& value, Ts&&... others) { + write_log_content(std::forward(value)); + write_log_content(std::forward(others)...); +} + +/// Prints the log to standard error. +/// +/// @tparam Ts Data types of values to be printed out. +/// +/// @param[in] level Severity level. +/// @param[in] args Values to be printed out. +template +void log(std::string_view level, Ts&&... args) { + std::cerr << level << " | "; + write_log_content(std::forward(args)...); + std::cerr << "\n"; +} + +} // namespace kai::test::detail diff --git a/test/common/memory.hpp b/test/common/memory.hpp new file mode 100644 index 00000000..bf5fbb01 --- /dev/null +++ b/test/common/memory.hpp @@ -0,0 +1,77 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "test/common/int4.hpp" + +namespace kai::test { + +/// The size in bits of type `T`. +template +inline constexpr size_t size_in_bits = sizeof(T) * 8; + +/// The size in bits of type `T`. +template <> +inline constexpr size_t size_in_bits = 4; + +/// The size in bits of type `T`. +template <> +inline constexpr size_t size_in_bits = 4; + +/// Reads the array at the specified index. +/// +/// @param[in] array Data buffer. +/// @param[in] index Array index. +/// +/// @return The array value at the specified index. +template +T read_array(const void* array, size_t index) { + if constexpr (std::is_same_v) { + const auto [lo, hi] = UInt4::unpack_u8(reinterpret_cast(array)[index / 2]); + return index % 2 == 0 ? lo : hi; + } else if constexpr (std::is_same_v) { + const auto [lo, hi] = Int4::unpack_u8(reinterpret_cast(array)[index / 2]); + return index % 2 == 0 ? lo : hi; + } else { + return reinterpret_cast(array)[index]; + } +} + +/// Writes the specified value to the array. +/// +/// @param[in] array Data buffer. +/// @param[in] index Array index. +/// @param[in] value Value to be stored. +template +void write_array(void* array, size_t index, T value) { + if constexpr (std::is_same_v) { + auto* arr_value = reinterpret_cast(array) + index / 2; + const auto [lo, hi] = UInt4::unpack_u8(*arr_value); + + if (index % 2 == 0) { + *arr_value = UInt4::pack_u8(value, hi); + } else { + *arr_value = UInt4::pack_u8(lo, value); + } + } else if constexpr (std::is_same_v) { + auto* arr_value = reinterpret_cast(array) + index / 2; + const auto [lo, hi] = Int4::unpack_u8(*arr_value); + + if (index % 2 == 0) { + *arr_value = Int4::pack_u8(value, hi); + } else { + *arr_value = Int4::pack_u8(lo, value); + } + } else { + reinterpret_cast(array)[index] = value; + } +} + +} // namespace kai::test diff --git a/test/common/numeric_limits.hpp b/test/common/numeric_limits.hpp new file mode 100644 index 00000000..e5950810 --- /dev/null +++ b/test/common/numeric_limits.hpp @@ -0,0 +1,39 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "test/common/int4.hpp" + +namespace kai::test { + +/// Highest finite value of type `T`. +template +inline constexpr T numeric_highest = std::numeric_limits::max(); + +/// Highest finite value of type `T`. +template <> +inline constexpr UInt4 numeric_highest{15}; + +/// Highest finite value of type `T`. +template <> +inline constexpr Int4 numeric_highest{7}; + +/// Lowest finite value of type `T`. +template +inline constexpr T numeric_lowest = std::numeric_limits::lowest(); + +/// Lowest finite value of type `T`. +template <> +inline constexpr UInt4 numeric_lowest{0}; + +/// Lowest finite value of type `T`. +template <> +inline constexpr Int4 numeric_lowest{-8}; + +} // namespace kai::test diff --git a/test/common/printer.cpp b/test/common/printer.cpp new file mode 100644 index 00000000..368ae7db --- /dev/null +++ b/test/common/printer.cpp @@ -0,0 +1,112 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/common/int4.hpp" + +namespace kai::test { + +namespace { + +inline void print_data(std::ostream& os, const uint8_t* data, size_t len, DataType data_type) { + if (data_type == DataType::QSU4) { + for (size_t i = 0; i < len / 2; ++i) { + const auto [low, high] = UInt4::unpack_u8(data[i]); + os << static_cast(low) << ", " << static_cast(high) << ", "; + } + } else if (data_type == DataType::QSI4) { + for (size_t i = 0; i < len / 2; ++i) { + const auto [low, high] = Int4::unpack_u8(data[i]); + os << static_cast(low) << ", " << static_cast(high) << ", "; + } + } else { + for (size_t i = 0; i < len; ++i) { + switch (data_type) { + case DataType::FP32: + os << reinterpret_cast(data)[i]; + break; + + case DataType::I32: + os << reinterpret_cast(data)[i]; + break; + + case DataType::QI8: + os << static_cast(reinterpret_cast(data)[i]); + break; + + default: + KAI_ERROR("Unsupported data type!"); + } + + os << ", "; + } + } +} + +void print_matrix_raw(std::ostream& os, const uint8_t* data, DataType data_type, size_t height, size_t width) { + const auto row_stride = width * data_type_size_in_bits(data_type) / 8; + + os << "[\n"; + for (size_t y = 0; y < height; ++y) { + os << " ["; + print_data(os, data + y * row_stride, width, data_type); + os << "],\n"; + } + os << "]\n"; +} + +void print_matrix_per_row( + std::ostream& os, const uint8_t* data, const DataFormat& format, size_t height, size_t width) { + const auto block_height = format.block_height(); + const auto num_blocks = (height + block_height - 1) / block_height; + + const auto block_data_bytes = block_height * width * data_type_size_in_bits(format.data_type()) / 8; + const auto block_offsets_bytes = block_height * data_type_size_in_bits(format.zero_point_data_type()) / 8; + const auto block_scales_bytes = block_height * data_type_size_in_bits(format.scale_data_type()) / 8; + + os << "[\n"; + for (size_t y = 0; y < num_blocks; ++y) { + os << " {\"offsets\": ["; + print_data(os, data, block_height, format.zero_point_data_type()); + os << "], \"data\": ["; + print_data(os, data + block_offsets_bytes, block_height * width, format.data_type()); + os << "], \"scales\": ["; + print_data(os, data + block_offsets_bytes + block_data_bytes, block_height, format.scale_data_type()); + os << "]},\n"; + + data += block_offsets_bytes + block_data_bytes + block_scales_bytes; + } + os << "]\n"; +} + +} // namespace + +void print_matrix( + std::ostream& os, std::string_view name, const void* data, const DataFormat& format, size_t height, size_t width) { + os << name << " = "; + + switch (format.quantization_format()) { + case DataFormat::QuantizationFormat::NONE: + print_matrix_raw(os, reinterpret_cast(data), format.data_type(), height, width); + break; + + case DataFormat::QuantizationFormat::PER_ROW: + print_matrix_per_row(os, reinterpret_cast(data), format, height, width); + break; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +} // namespace kai::test diff --git a/test/common/printer.hpp b/test/common/printer.hpp new file mode 100644 index 00000000..ec6213ae --- /dev/null +++ b/test/common/printer.hpp @@ -0,0 +1,28 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace kai::test { + +class DataFormat; + +/// Prints the matrix data to the output stream. +/// +/// @param[in] os Output stream to write the data to. +/// @param[in] name Matrix name. +/// @param[in] data Data buffer. +/// @param[in] format Data format. +/// @param[in] height Number of rows. +/// @param[in] width Number of columns. +void print_matrix( + std::ostream& os, std::string_view name, const void* data, const DataFormat& format, size_t height, size_t width); + +} // namespace kai::test diff --git a/test/common/type_traits.hpp b/test/common/type_traits.hpp new file mode 100644 index 00000000..2267432a --- /dev/null +++ b/test/common/type_traits.hpp @@ -0,0 +1,79 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace kai::test { + +class UInt4; +class Int4; + +/// `true` if `T` is unsigned numeric type. +template +inline constexpr bool is_unsigned = std::is_unsigned_v; + +/// `true` if `T` is unsigned numeric type. +template <> +inline constexpr bool is_unsigned = true; + +/// `true` if `T` is unsigned numeric type. +template <> +inline constexpr bool is_unsigned = true; + +/// `true` if `T` is signed numeric type. +template +inline constexpr bool is_signed = std::is_signed_v; + +/// `true` if `T` is signed numeric type. +template <> +inline constexpr bool is_signed = false; + +/// `true` if `T` is signed numeric type. +template <> +inline constexpr bool is_signed = false; + +/// `true` if `T` is integral numeric type. +template +inline constexpr bool is_integral = std::is_integral_v; + +/// `true` if `T` is integral numeric type. +template <> +inline constexpr bool is_integral = true; + +/// `true` if `T` is integral numeric type. +template <> +inline constexpr bool is_integral = true; + +/// `true` if `T` is floating-point type. +template +inline constexpr bool is_floating_point = std::is_floating_point_v; + +/// Signed version of type `T`. +template +struct make_signed { + using type = std::make_signed_t; +}; + +/// Signed version of type `T`. +template <> +struct make_signed { + using type = Int4; +}; + +/// Signed version of type `T`. +template <> +struct make_signed { + using type = Int4; +}; + +/// Signed version of type `T`. +template +using make_signed_t = typename make_signed::type; + +} // namespace kai::test diff --git a/test/reference/round.cpp b/test/reference/round.cpp new file mode 100644 index 00000000..52a2c557 --- /dev/null +++ b/test/reference/round.cpp @@ -0,0 +1,36 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/round.hpp" + +#include +#include + +namespace kai::test { + +int32_t round_to_nearest_even_i32(float value) { + int32_t rounded = 0; + asm("fcvtns %w[output], %s[input]" : [output] "=r"(rounded) : [input] "w"(value)); + return rounded; +} + +size_t round_to_nearest_even_usize(float value) { + static_assert(sizeof(size_t) == sizeof(uint64_t)); + + uint64_t rounded = 0; + asm("fcvtns %x[output], %s[input]" : [output] "=r"(rounded) : [input] "w"(value)); + return rounded; +} + +size_t round_up_multiple(size_t a, size_t b) { + return ((a + b - 1) / b) * b; +} + +size_t round_down_multiple(size_t a, size_t b) { + return (a / b) * b; +} + +} // namespace kai::test diff --git a/test/reference/round.hpp b/test/reference/round.hpp new file mode 100644 index 00000000..7503936f --- /dev/null +++ b/test/reference/round.hpp @@ -0,0 +1,62 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace kai::test { + +/// Rounds the specified value to nearest with tie to even. +/// +/// For example: +/// +/// * 0.4 is rounded to 0. +/// * 0.5 is rounded to 0 (as 0 is the nearest even value). +/// * 0.6 is rounded to 1. +/// * 1.4 is rounded to 1. +/// * 1.5 is rounded to 2 (as 2 is the nearest even value). +/// * 1.6 is rounded to 2. +/// +/// @param[in] value Value to be rounded. +/// +/// @return The rounded value. +int32_t round_to_nearest_even_i32(float value); + +/// Rounds the specified value to nearest with tie to even. +/// +/// For example: +/// +/// * 0.4 is rounded to 0. +/// * 0.5 is rounded to 0 (as 0 is the nearest even value). +/// * 0.6 is rounded to 1. +/// * 1.4 is rounded to 1. +/// * 1.5 is rounded to 2 (as 2 is the nearest even value). +/// * 1.6 is rounded to 2. +/// +/// @param[in] value Value to be rounded. +/// +/// @return The rounded value. +size_t round_to_nearest_even_usize(float value); + +/// Rounds up the input value to the multiple of the unit value. +/// +/// @param[in] a Input value. +/// @param[in] b Unit value. +/// +/// @return The rounded value. +size_t round_up_multiple(size_t a, size_t b); + +/// Rounds down the input value to the multiple of the unit value. +/// +/// @param[in] a Input value. +/// @param[in] b Unit value. +/// +/// @return The rounded value. +size_t round_down_multiple(size_t a, size_t b); + +} // namespace kai::test -- GitLab From fed53b99e792afd687b8f6d11a8edcfe7e8cfa14 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Tue, 30 Apr 2024 13:26:23 +0100 Subject: [PATCH 02/11] Add matrix portion and comparison * Add matrix portion and rectangular region to specify a sub-matrix from a full matrix. * Add matrix comparison and mismatch handler to compare two matrices. - Support both raw and per-row quantization formats. Signed-off-by: Viet-Hoa Do --- CMakeLists.txt | 3 + test/common/compare.cpp | 270 +++++++++++++++++++++++++++++++++ test/common/compare.hpp | 125 +++++++++++++++ test/common/matrix_portion.cpp | 64 ++++++++ test/common/matrix_portion.hpp | 68 +++++++++ test/common/rect.cpp | 41 +++++ test/common/rect.hpp | 51 +++++++ 7 files changed, 622 insertions(+) create mode 100644 test/common/compare.cpp create mode 100644 test/common/compare.hpp create mode 100644 test/common/matrix_portion.cpp create mode 100644 test/common/matrix_portion.hpp create mode 100644 test/common/rect.cpp create mode 100644 test/common/rect.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index b6d9662c..5266a3a5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,6 +60,9 @@ if(KLEIDIAI_BUILD_TESTS) test/common/data_format.cpp test/common/printer.cpp test/common/int4.cpp + test/common/compare.cpp + test/common/matrix_portion.cpp + test/common/rect.cpp test/reference/round.cpp ) diff --git a/test/common/compare.cpp b/test/common/compare.cpp new file mode 100644 index 00000000..206ec0e5 --- /dev/null +++ b/test/common/compare.cpp @@ -0,0 +1,270 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/compare.hpp" + +#include +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/common/int4.hpp" +#include "test/common/logging.hpp" +#include "test/common/memory.hpp" +#include "test/common/printer.hpp" +#include "test/common/rect.hpp" + +namespace kai::test { + +namespace { + +/// Calculates the absolute and relative errors. +/// +/// @param[in] imp Value under test. +/// @param[in] ref Reference value. +/// +/// @return The absolute error and relative error. +template +std::tuple calculate_error(T imp, T ref) { + const auto imp_f = static_cast(imp); + const auto ref_f = static_cast(ref); + + const auto abs_error = std::abs(imp_f - ref_f); + const auto rel_error = ref_f != 0 ? abs_error / std::abs(ref_f) : 0.0F; + + return {abs_error, rel_error}; +} + +/// Compares matrices with per-row quantization. +template +bool compare_raw( + const void* imp_data, const void* ref_data, size_t full_height, size_t full_width, const Rect& rect, + MismatchHandler& handler) { + for (size_t y = 0; y < full_height; ++y) { + for (size_t x = 0; x < full_width; ++x) { + const auto in_roi = + y >= rect.start_row() && y < rect.end_row() && x >= rect.start_col() && x < rect.end_col(); + + const auto imp_value = read_array(imp_data, y * full_width + x); + const auto ref_value = in_roi ? read_array(ref_data, y * full_width + x) : 0; + + const auto [abs_err, rel_err] = calculate_error(imp_value, ref_value); + + if (abs_err != 0 || rel_err != 0) { + const auto notifying = !in_roi || handler.handle_data(abs_err, rel_err); + + if (notifying) { + KAI_LOGE("Mismatched data at (", y, ", ", x, "): actual = ", imp_value, ", expected: ", ref_value); + } + } + } + } + + return handler.success(full_height * full_width); +} + +/// Compares matrices with per-row quantization. +template +bool compare_per_row( + const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, + const Rect& rect, MismatchHandler& handler) { + const auto block_height = format.block_height(); + const auto block_width = format.block_width(); + + KAI_ASSUME(format.scheduler_block_height(full_height) == block_height); + KAI_ASSUME(format.scheduler_block_width(full_width) == full_width); + KAI_ASSUME(rect.start_col() == 0); + KAI_ASSUME(rect.width() == full_width); + + const auto data_bits = size_in_bits; + + const auto num_groups = (full_height + block_height - 1) / block_height; + const auto group_num_blocks = (full_width + block_width - 1) / block_width; + + const auto group_offsets_bytes = block_height * sizeof(Offset); + const auto group_scales_bytes = block_height * sizeof(Scale); + const auto block_data_bytes = block_height * block_width * data_bits / 8; + + const auto begin_group = rect.start_row() / block_height; + const auto end_group = rect.end_row() / block_height; + + const auto* imp_ptr = reinterpret_cast(imp_data); + const auto* ref_ptr = reinterpret_cast(ref_data); + + for (size_t group_no = 0; group_no < num_groups; ++group_no) { + const auto in_roi = group_no >= begin_group && group_no < end_group; + + // Checks the quantization offsets. + for (size_t i = 0; i < block_height; ++i) { + const auto imp_offset = reinterpret_cast(imp_ptr)[i]; + const Offset ref_offset = in_roi ? reinterpret_cast(ref_ptr)[i] : 0; + const auto [abs_err, rel_err] = calculate_error(imp_offset, ref_offset); + + if (abs_err != 0 || rel_err != 0) { + handler.mark_as_failed(); + + const auto raw_row = group_no * block_height + i; + KAI_LOGE( + "Mismatched quantization offset ", raw_row, ": actual = ", imp_offset, ", expected: ", ref_offset); + } + } + + imp_ptr += group_offsets_bytes; + ref_ptr += group_offsets_bytes; + + // Checks the data. + for (size_t block_no = 0; block_no < group_num_blocks; ++block_no) { + for (size_t y = 0; y < block_height; ++y) { + for (size_t x = 0; x < block_width; ++x) { + const auto imp_data = read_array(imp_ptr, y * block_width + x); + const Data ref_data = in_roi ? read_array(ref_ptr, y * block_width + x) : Data(0); + const auto [abs_err, rel_err] = calculate_error(imp_data, ref_data); + + if (abs_err != 0 || rel_err != 0) { + const auto notifying = !in_roi || handler.handle_data(abs_err, rel_err); + + if (notifying) { + const auto raw_row = group_no * block_height + y; + const auto raw_col = block_no * block_width + x; + + KAI_LOGE( + "Mismatched data at (", raw_row, ", ", raw_col, "): actual = ", imp_data, + ", expected: ", ref_data); + } + } + } + } + + imp_ptr += block_data_bytes; + ref_ptr += block_data_bytes; + } + + // Checks the quantization scales. + for (size_t i = 0; i < block_height; ++i) { + const auto imp_scale = reinterpret_cast(imp_ptr)[i]; + const Scale ref_scale = in_roi ? reinterpret_cast(ref_ptr)[i] : 0; + const auto [abs_err, rel_err] = calculate_error(imp_scale, ref_scale); + + if (abs_err != 0 || rel_err != 0) { + handler.mark_as_failed(); + + const auto raw_row = group_no * block_height + i; + KAI_LOGE( + "Mismatched quantization scale ", raw_row, ": actual = ", imp_scale, ", expected: ", ref_scale); + } + } + + imp_ptr += group_scales_bytes; + ref_ptr += group_scales_bytes; + } + + return handler.success(rect.height() * full_width); +} + +} // namespace + +bool compare( + const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, + const Rect& rect, MismatchHandler& handler) { + const auto data_type = format.data_type(); + const auto scale_dt = format.scale_data_type(); + const auto offset_dt = format.zero_point_data_type(); + + switch (format.quantization_format()) { + case DataFormat::QuantizationFormat::NONE: + switch (data_type) { + case DataType::FP32: + return compare_raw(imp_data, ref_data, full_height, full_width, rect, handler); + + default: + break; + } + + case DataFormat::QuantizationFormat::PER_ROW: + if (data_type == DataType::QI8 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { + return compare_per_row( + imp_data, ref_data, format, full_height, full_width, rect, handler); + } else if (data_type == DataType::QSI4 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { + return compare_per_row( + imp_data, ref_data, format, full_height, full_width, rect, handler); + } + + break; + + default: + break; + } + + KAI_ERROR("Unsupported format!"); +} + +// ===================================================================================================================== + +DefaultMismatchHandler::DefaultMismatchHandler( + float abs_error_threshold, float rel_error_threshold, size_t abs_mismatched_threshold, + float rel_mismatched_threshold) : + _abs_error_threshold(abs_error_threshold), + _rel_error_threshold(rel_error_threshold), + _abs_mismatched_threshold(abs_mismatched_threshold), + _rel_mismatched_threshold(rel_mismatched_threshold), + _num_mismatches(0), + _failed(false) { +} + +DefaultMismatchHandler::DefaultMismatchHandler(const DefaultMismatchHandler& rhs) : + _abs_error_threshold(rhs._abs_error_threshold), + _rel_error_threshold(rhs._rel_error_threshold), + _abs_mismatched_threshold(rhs._abs_mismatched_threshold), + _rel_mismatched_threshold(rhs._rel_mismatched_threshold), + _num_mismatches(0), + _failed(false) { + // Cannot copy mismatch handler that is already in use. + KAI_ASSUME(rhs._num_mismatches == 0); + KAI_ASSUME(!rhs._failed); +} + +DefaultMismatchHandler& DefaultMismatchHandler::operator=(const DefaultMismatchHandler& rhs) { + if (this != &rhs) { + // Cannot copy mismatch handler that is already in use. + KAI_ASSUME(rhs._num_mismatches == 0); + KAI_ASSUME(!rhs._failed); + + _abs_error_threshold = rhs._abs_error_threshold; + _rel_error_threshold = rhs._rel_error_threshold; + _abs_mismatched_threshold = rhs._abs_mismatched_threshold; + _rel_mismatched_threshold = rhs._rel_mismatched_threshold; + } + + return *this; +} + +bool DefaultMismatchHandler::handle_data(float absolute_error, float relative_error) { + const auto mismatched = absolute_error > _abs_error_threshold && relative_error > _rel_error_threshold; + + if (mismatched) { + ++_num_mismatches; + } + + return mismatched; +} + +void DefaultMismatchHandler::mark_as_failed() { + _failed = true; +} + +bool DefaultMismatchHandler::success(size_t num_checks) const { + if (_failed) { + return false; + } + + const auto mismatched_rate = static_cast(_num_mismatches) / static_cast(num_checks); + return _num_mismatches <= _abs_mismatched_threshold || mismatched_rate <= _rel_mismatched_threshold; +} + +} // namespace kai::test diff --git a/test/common/compare.hpp b/test/common/compare.hpp new file mode 100644 index 00000000..3bca0a5c --- /dev/null +++ b/test/common/compare.hpp @@ -0,0 +1,125 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace kai::test { + +class DataFormat; +class Rect; +class MismatchHandler; + +/// Compares two matrices to check whether they are matched. +/// +/// @param[in] imp_data Data buffer of the actual implementation matrix. +/// @param[in] ref_data Data buffer of the reference implementation matrix. +/// @param[in] format Data format. +/// @param[in] full_height Height of the full matrix. +/// @param[in] full_width Width of the full matrix. +/// @param[in] rect Rectangular region of the matrix that is populated with data. +/// @param[in] handler Mismatch handler. +/// +/// @return `true` if the two matrices are considered matched. +bool compare( + const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, + const Rect& rect, MismatchHandler& handler); + +/// Handles mismatches found by @ref validate function. +class MismatchHandler { +public: + /// Constructor. + MismatchHandler() = default; + + /// Destructor. + virtual ~MismatchHandler() = default; + + /// Copy constructor. + MismatchHandler(const MismatchHandler&) = default; + + /// Copy assignment. + MismatchHandler& operator=(const MismatchHandler&) = default; + + /// Move constructor. + MismatchHandler(MismatchHandler&&) noexcept = default; + + /// Move assignment. + MismatchHandler& operator=(MismatchHandler&&) noexcept = default; + + /// Handles new mismatch result. + /// + /// This method must be called even when no error is detected. + /// + /// @param[in] absolute_error Absolute error. + /// @param[in] relative_error Relative error. + /// + /// @return `true` if the mismatch is sufficiently large to be logged as real mismatch. + virtual bool handle_data(float absolute_error, float relative_error) = 0; + + /// Marks the result as failed. + /// + /// It is zero tolerance if the data point is considered impossible to have mismatch + /// regardless of implementation method. + /// These normally include data point outside if the portion of interest (these must be 0) + /// and data point belongs to quantization information. + virtual void mark_as_failed() = 0; + + /// Returns a value indicating whether the two matrices are considered matched. + /// + /// @param[in] num_checks Total number of data points that have been checked. + /// + /// @return `true` if the two matrices are considered matched. + [[nodiscard]] virtual bool success(size_t num_checks) const = 0; +}; + +/// This mismatch handler considers two values being mismatched when both the relative error +/// and the absolute error exceed their respective thresholds. +/// +/// This mismatch handler considers two matrices being mismatched when the number of mismatches +/// exceed both the relative and absolute thresholds. +class DefaultMismatchHandler final : public MismatchHandler { +public: + /// Creates a new mismatch handler. + /// + /// @param[in] abs_error_threshold Threshold for absolute error + /// @param[in] rel_error_threshold Threshold for relative error. + /// @param[in] abs_mismatched_threshold Threshold for the number of mismatched data points. + /// @param[in] rel_mismatched_threshold Threshold for the ratio of mismatched data points. + DefaultMismatchHandler( + float abs_error_threshold, float rel_error_threshold, size_t abs_mismatched_threshold, + float rel_mismatched_threshold); + + /// Destructur. + ~DefaultMismatchHandler() = default; + + /// Copy constructor. + DefaultMismatchHandler(const DefaultMismatchHandler& rhs); + + /// Copy assignment. + DefaultMismatchHandler& operator=(const DefaultMismatchHandler& rhs); + + /// Move constructor. + DefaultMismatchHandler(DefaultMismatchHandler&& rhs) noexcept = default; + + /// Move assignment. + DefaultMismatchHandler& operator=(DefaultMismatchHandler&& rhs) noexcept = default; + + bool handle_data(float absolute_error, float relative_error) override; + void mark_as_failed() override; + [[nodiscard]] bool success(size_t num_checks) const override; + +private: + float _abs_error_threshold; + float _rel_error_threshold; + size_t _abs_mismatched_threshold; + float _rel_mismatched_threshold; + + size_t _num_mismatches; + bool _failed; +}; + +} // namespace kai::test diff --git a/test/common/matrix_portion.cpp b/test/common/matrix_portion.cpp new file mode 100644 index 00000000..abf795df --- /dev/null +++ b/test/common/matrix_portion.cpp @@ -0,0 +1,64 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/matrix_portion.hpp" + +#include +#include + +#include "test/common/rect.hpp" +#include "test/reference/round.hpp" + +namespace kai::test { + +MatrixPortion::MatrixPortion(float start_row, float start_col, float height, float width) : + _start_row(start_row), _start_col(start_col), _height(height), _width(width) { +} + +float MatrixPortion::start_row() const { + return _start_row; +} + +float MatrixPortion::start_col() const { + return _start_col; +} + +float MatrixPortion::height() const { + return _height; +} + +float MatrixPortion::width() const { + return _width; +} + +Rect MatrixPortion::compute_portion( + size_t full_height, size_t full_width, size_t scheduler_block_height, size_t scheduler_block_width) const { + const auto start_row_f = std::clamp(_start_row, 0, 1); + const auto start_col_f = std::clamp(_start_col, 0, 1); + const auto height_f = std::clamp(_height, 0, 1); + const auto width_f = std::clamp(_width, 0, 1); + + auto start_row = round_to_nearest_even_usize(start_row_f * static_cast(full_height)); + auto start_col = round_to_nearest_even_usize(start_col_f * static_cast(full_width)); + auto height = round_to_nearest_even_usize(height_f * static_cast(full_height)); + auto width = round_to_nearest_even_usize(width_f * static_cast(full_width)); + + start_row = round_down_multiple(start_row, scheduler_block_height); + start_col = round_down_multiple(start_col, scheduler_block_width); + + start_row = std::min(start_row, round_down_multiple(full_height, scheduler_block_height)); + start_col = std::min(start_col, round_down_multiple(full_width, scheduler_block_width)); + + height = round_up_multiple(height, scheduler_block_height); + width = round_up_multiple(width, scheduler_block_width); + + height = std::min(height, full_height - start_row); + width = std::min(width, full_width - start_col); + + return {start_row, start_col, height, width}; +} + +} // namespace kai::test diff --git a/test/common/matrix_portion.hpp b/test/common/matrix_portion.hpp new file mode 100644 index 00000000..02f9f548 --- /dev/null +++ b/test/common/matrix_portion.hpp @@ -0,0 +1,68 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "test/common/rect.hpp" + +namespace kai::test { + +/// Portion of a matrix. +/// +/// This class is used to define the sub-matrix a test is running and checking. +/// +/// This is the relative version of @ref Rect. +class MatrixPortion { +public: + /// Creates a new matrix portion. + /// + /// @param[in] start_row Starting row as the ratio to the height of the matrix. + /// @param[in] start_col Starting column as the ratio to the width of the matrix. + /// @param[in] height Portion height as the ratio to the height of the matrix. + /// @param[in] width Portion width as the ratio to the width of the matrix. + MatrixPortion(float start_row, float start_col, float height, float width); + + /// Gets the starting row as the ratio to the height of the matrix. + [[nodiscard]] float start_row() const; + + /// Gets the starting column as the ratio to the width of the matrix. + [[nodiscard]] float start_col() const; + + /// Gets the portion height as the ratio to the height of the matrix. + [[nodiscard]] float height() const; + + /// Gets the portion width as the ratio to the width of the matrix. + [[nodiscard]] float width() const; + + /// Computes the starting coordinate and the shape of the sub-matrix. + /// + /// Requirements: + /// + /// * The starting coordinate of the sub-matrix shall be aligned with the block boundary. + /// * If it is not the block at the right and/or bottom edge of the full matrix, the height and width + /// of the sub-matrix shall be rounded up to multiple of the block height and width. + /// * If it is the block at the right and/or bottom edge of the full matrix, the height and width + /// of the sub-matrix shall be the rounded up to the edge of the matrix. + /// + /// @param[in] full_height Matrix height. + /// @param[in] full_width Matrix width. + /// @param[in] scheduler_block_height Block height for scheduling purpose. + /// @param[in] scheduler_block_width Block width for scheduling purpose. + /// + /// @return The rectangular region of the matrix. + [[nodiscard]] Rect compute_portion( + size_t full_height, size_t full_width, size_t scheduler_block_height, size_t scheduler_block_width) const; + +private: + float _start_row; + float _start_col; + float _height; + float _width; +}; + +} // namespace kai::test diff --git a/test/common/rect.cpp b/test/common/rect.cpp new file mode 100644 index 00000000..457ac2f3 --- /dev/null +++ b/test/common/rect.cpp @@ -0,0 +1,41 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/rect.hpp" + +#include + +namespace kai::test { + +Rect::Rect(size_t start_row, size_t start_col, size_t height, size_t width) : + _start_row(start_row), _start_col(start_col), _height(height), _width(width) { +} + +size_t Rect::start_row() const { + return _start_row; +} + +size_t Rect::start_col() const { + return _start_col; +} + +size_t Rect::height() const { + return _height; +} + +size_t Rect::width() const { + return _width; +} + +size_t Rect::end_row() const { + return _start_row + _height; +} + +size_t Rect::end_col() const { + return _start_col + _width; +} + +} // namespace kai::test diff --git a/test/common/rect.hpp b/test/common/rect.hpp new file mode 100644 index 00000000..92b033b4 --- /dev/null +++ b/test/common/rect.hpp @@ -0,0 +1,51 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace kai::test { + +/// Rectangular region of a matrix. +/// +/// This is the absolute version of @ref MatrixPortion. +class Rect { +public: + /// Creates a new rectangular region of a matrix. + /// + /// @param[in] start_row Starting row index. + /// @param[in] start_col Starting column index. + /// @param[in] height Number of rows. + /// @param[in] width Number of columns. + Rect(size_t start_row, size_t start_col, size_t height, size_t width); + + /// Gets the starting row index. + [[nodiscard]] size_t start_row() const; + + /// Gets the starting column index. + [[nodiscard]] size_t start_col() const; + + /// Gets the number of rows. + [[nodiscard]] size_t height() const; + + /// Gets the number of columns. + [[nodiscard]] size_t width() const; + + /// Gets the end (exclusive) row index. + [[nodiscard]] size_t end_row() const; + + /// Gets the end (exclusive) column index. + [[nodiscard]] size_t end_col() const; + +private: + size_t _start_row; + size_t _start_col; + size_t _height; + size_t _width; +}; + +} // namespace kai::test -- GitLab From f0e107b1a5cb01fe49474b8473ee8b1a180b7c0a Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Tue, 30 Apr 2024 13:34:55 +0100 Subject: [PATCH 03/11] Add reference implementation for simple operators * Add reference implementation for the following simple operators: - Binary element-wise. - Fill. - Quantize. - Reduce. Signed-off-by: Viet-Hoa Do --- .clang-tidy | 3 +- CMakeLists.txt | 4 + test/reference/binary_elementwise.cpp | 149 ++++++++++++++++++++++++++ test/reference/binary_elementwise.hpp | 89 +++++++++++++++ test/reference/fill.cpp | 97 +++++++++++++++++ test/reference/fill.hpp | 27 +++++ test/reference/quantize.cpp | 118 ++++++++++++++++++++ test/reference/quantize.hpp | 60 +++++++++++ test/reference/reduce.cpp | 113 +++++++++++++++++++ test/reference/reduce.hpp | 36 +++++++ 10 files changed, 695 insertions(+), 1 deletion(-) create mode 100644 test/reference/binary_elementwise.cpp create mode 100644 test/reference/binary_elementwise.hpp create mode 100644 test/reference/fill.cpp create mode 100644 test/reference/fill.hpp create mode 100644 test/reference/quantize.cpp create mode 100644 test/reference/quantize.hpp create mode 100644 test/reference/reduce.cpp create mode 100644 test/reference/reduce.hpp diff --git a/.clang-tidy b/.clang-tidy index bebfc13a..4c0dfbe8 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -26,6 +26,7 @@ readability-*, -cppcoreguidelines-avoid-magic-numbers, -readability-simplify-boolean-expr, -bugprone-easily-swappable-parameters, --cppcoreguidelines-pro-bounds-pointer-arithmetic +-cppcoreguidelines-pro-bounds-pointer-arithmetic, +-performance-enum-size ' ... diff --git a/CMakeLists.txt b/CMakeLists.txt index 5266a3a5..b7e6786d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,6 +63,10 @@ if(KLEIDIAI_BUILD_TESTS) test/common/compare.cpp test/common/matrix_portion.cpp test/common/rect.cpp + test/reference/binary_elementwise.cpp + test/reference/fill.cpp + test/reference/quantize.cpp + test/reference/reduce.cpp test/reference/round.cpp ) diff --git a/test/reference/binary_elementwise.cpp b/test/reference/binary_elementwise.cpp new file mode 100644 index 00000000..eda5d8ae --- /dev/null +++ b/test/reference/binary_elementwise.cpp @@ -0,0 +1,149 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/binary_elementwise.hpp" + +#include +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_type.hpp" +#include "test/common/int4.hpp" +#include "test/common/memory.hpp" + +namespace kai::test { + +namespace { + +/// Binary element-wise operator. +enum class BinaryElementwiseOperator : uint32_t { + ADD, ///< Addition. + SUB, ///< Subtraction. + MUL, ///< Multiplication. + DIV, ///< Division. +}; + +/// Scalar binary element-wise function. +/// +/// @tparam op Binary element-wise operator to perform. +/// @tparam T Data type. +/// +/// @param[in] lhs LHS operand. +/// @param[in] rhs RHS operand. +/// +/// @return The result of the operation. +template +T scalar_binary_elementwise(T lhs, T rhs) { + if constexpr (op == BinaryElementwiseOperator::ADD) { + return lhs + rhs; + } else if constexpr (op == BinaryElementwiseOperator::SUB) { + return lhs - rhs; + } else if constexpr (op == BinaryElementwiseOperator::MUL) { + return lhs * rhs; + } else if constexpr (op == BinaryElementwiseOperator::DIV) { + return lhs / rhs; + } else { + KAI_ERROR("Unsupported binary element-wise operator!"); + } +} + +/// Binary element-wise function. +/// +/// @tparam op Binary element-wise operator to perform. +/// @tparam T Data type. +/// +/// @param[in] lhs LHS data buffer. +/// @param[in] rhs RHS data buffer. +/// @param[in] lhs_height LHS height. +/// @param[in] lhs_width LHS width. +/// @param[in] rhs_height RHS height. +/// @param[in] rhs_width RHS width. +/// +/// @return The result data buffer. +template +std::vector binary_elementwise_any_op_type( + const void* lhs, const void* rhs, size_t lhs_height, size_t lhs_width, size_t rhs_height, size_t rhs_width) { + const auto height = std::max(lhs_height, rhs_height); + const auto width = std::max(lhs_width, rhs_width); + + std::vector dst; + dst.resize(height * width * size_in_bits / 8); + KAI_ASSUME(width * size_in_bits % 8 == 0); + + for (size_t y = 0; y < height; ++y) { + for (size_t x = 0; x < width; ++x) { + const auto lhs_y = lhs_height > 1 ? y : 0; + const auto lhs_x = lhs_width > 1 ? x : 0; + const auto lhs_value = read_array(lhs, lhs_y * lhs_width + lhs_x); + + const auto rhs_y = rhs_height > 1 ? y : 0; + const auto rhs_x = rhs_width > 1 ? x : 0; + const auto rhs_value = read_array(rhs, rhs_y * rhs_width + rhs_x); + + const auto dst_value = scalar_binary_elementwise(lhs_value, rhs_value); + write_array(dst.data(), y * width + x, dst_value); + } + } + + return dst; +} + +template +std::vector binary_elementwise_any_type( + const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // + const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) { + KAI_ASSUME(lhs_dt == rhs_dt); + KAI_ASSUME(lhs_height == 1 || rhs_height == 1 || lhs_height == rhs_height); + KAI_ASSUME(lhs_width == 1 || rhs_width == 1 || lhs_width == rhs_width); + + switch (lhs_dt) { + case DataType::FP32: + return binary_elementwise_any_op_type(lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width); + + case DataType::I32: + return binary_elementwise_any_op_type(lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width); + + case DataType::QSU4: + return binary_elementwise_any_op_type(lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width); + + default: + KAI_ERROR("Unsupported data type!"); + } +} + +} // namespace + +std::vector add( + const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // + const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) { + return binary_elementwise_any_type( + lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width); +} + +std::vector sub( + const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // + const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) { + return binary_elementwise_any_type( + lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width); +} + +std::vector mul( + const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // + const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) { + return binary_elementwise_any_type( + lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width); +} + +std::vector div( + const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // + const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) { + return binary_elementwise_any_type( + lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width); +} + +} // namespace kai::test diff --git a/test/reference/binary_elementwise.hpp b/test/reference/binary_elementwise.hpp new file mode 100644 index 00000000..e3d3a9e1 --- /dev/null +++ b/test/reference/binary_elementwise.hpp @@ -0,0 +1,89 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +#include "test/common/data_type.hpp" + +namespace kai::test { + +/// Elementwise addition. +/// +/// Broadcasting is supported for any dimension and both LHS and RHS operands. +/// +/// @param[in] lhs LHS data buffer. +/// @param[in] lhs_dt LHS data type. +/// @param[in] lhs_height LHS height. +/// @param[in] lhs_width LHS width. +/// @param[in] rhs RHS data buffer. +/// @param[in] rhs_dt RHS data type. +/// @param[in] rhs_height RHS height. +/// @param[in] rhs_width RHS width. +/// +/// @return The result matrix. +std::vector add( + const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // + const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width); + +/// Elementwise subtraction. +/// +/// Broadcasting is supported for any dimension and both LHS and RHS operands. +/// +/// @param[in] lhs LHS data buffer. +/// @param[in] lhs_dt LHS data type. +/// @param[in] lhs_height LHS height. +/// @param[in] lhs_width LHS width. +/// @param[in] rhs RHS data buffer. +/// @param[in] rhs_dt RHS data type. +/// @param[in] rhs_height RHS height. +/// @param[in] rhs_width RHS width. +/// +/// @return The result matrix. +std::vector sub( + const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // + const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width); + +/// Elementwise multiplication. +/// +/// Broadcasting is supported for any dimension and both LHS and RHS operands. +/// +/// @param[in] lhs LHS data buffer. +/// @param[in] lhs_dt LHS data type. +/// @param[in] lhs_height LHS height. +/// @param[in] lhs_width LHS width. +/// @param[in] rhs RHS data buffer. +/// @param[in] rhs_dt RHS data type. +/// @param[in] rhs_height RHS height. +/// @param[in] rhs_width RHS width. +/// +/// @return The result matrix. +std::vector mul( + const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // + const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width); + +/// Elementwise division. +/// +/// Broadcasting is supported for any dimension and both LHS and RHS operands. +/// +/// @param[in] lhs LHS data buffer. +/// @param[in] lhs_dt LHS data type. +/// @param[in] lhs_height LHS height. +/// @param[in] lhs_width LHS width. +/// @param[in] rhs RHS data buffer. +/// @param[in] rhs_dt RHS data type. +/// @param[in] rhs_height RHS height. +/// @param[in] rhs_width RHS width. +/// +/// @return The result matrix. +std::vector div( + const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // + const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width); + +} // namespace kai::test diff --git a/test/reference/fill.cpp b/test/reference/fill.cpp new file mode 100644 index 00000000..d0fd8b5c --- /dev/null +++ b/test/reference/fill.cpp @@ -0,0 +1,97 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/fill.hpp" + +#include +#include +#include +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/common/int4.hpp" +#include "test/common/memory.hpp" + +namespace kai::test { + +namespace { + +template +std::vector fill_matrix_raw(size_t height, size_t width, std::function gen) { + const auto size = height * width * size_in_bits / 8; + KAI_ASSUME(width * size_in_bits % 8 == 0); + + std::vector data; + data.resize(size); + auto ptr = reinterpret_cast(data.data()); + + for (size_t y = 0; y < height; ++y) { + for (size_t x = 0; x < width; ++x) { + write_array(ptr, y * width + x, gen(y, x)); + } + } + + return data; +} + +template +std::vector fill_matrix_random_raw(size_t height, size_t width, uint64_t seed) { + using TDist = std::conditional_t< + std::is_floating_point_v, std::uniform_real_distribution, std::uniform_int_distribution>; + + std::mt19937 rnd(seed); + TDist dist; + + return fill_matrix_raw(height, width, [&](size_t, size_t) { return dist(rnd); }); +} + +template <> +std::vector fill_matrix_random_raw(size_t height, size_t width, uint64_t seed) { + std::mt19937 rnd(seed); + std::uniform_int_distribution dist(-8, 7); + + return fill_matrix_raw(height, width, [&](size_t, size_t) { return Int4(dist(rnd)); }); +} + +template <> +std::vector fill_matrix_random_raw(size_t height, size_t width, uint64_t seed) { + std::mt19937 rnd(seed); + std::uniform_int_distribution dist(0, 15); + + return fill_matrix_raw(height, width, [&](size_t, size_t) { return UInt4(dist(rnd)); }); +} + +} // namespace + +std::vector fill_matrix_random(size_t height, size_t width, const DataFormat& format, uint64_t seed) { + switch (format.quantization_format()) { + case DataFormat::QuantizationFormat::NONE: + switch (format.data_type()) { + case DataType::FP32: + return fill_matrix_random_raw(height, width, seed); + + case DataType::QSU4: + return fill_matrix_random_raw(height, width, seed); + + case DataType::QSI4: + return fill_matrix_random_raw(height, width, seed); + + default: + KAI_ERROR("Unsupported data type!"); + } + + break; + + default: + KAI_ERROR("Unsupported data format!"); + } +} + +} // namespace kai::test diff --git a/test/reference/fill.hpp b/test/reference/fill.hpp new file mode 100644 index 00000000..09138145 --- /dev/null +++ b/test/reference/fill.hpp @@ -0,0 +1,27 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace kai::test { + +class DataFormat; + +/// Creates a new matrix filled with random data. +/// +/// @param[in] height Number of rows. +/// @param[in] width Number of columns. +/// @param[in] format Data format. +/// @param[in] seed Random seed. +/// +/// @return The data buffer for the matrix. +std::vector fill_matrix_random(size_t height, size_t width, const DataFormat& format, uint64_t seed); + +} // namespace kai::test diff --git a/test/reference/quantize.cpp b/test/reference/quantize.cpp new file mode 100644 index 00000000..fab745e6 --- /dev/null +++ b/test/reference/quantize.cpp @@ -0,0 +1,118 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/quantize.hpp" + +#include +#include +#include +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_type.hpp" +#include "test/common/int4.hpp" +#include "test/common/memory.hpp" +#include "test/common/numeric_limits.hpp" +#include "test/common/type_traits.hpp" +#include "test/reference/round.hpp" + +namespace kai::test { + +std::tuple get_qi8_scale_zero_point_from_range(float min_value, float max_value) { + constexpr float q_min = std::numeric_limits::min(); + constexpr float q_max = std::numeric_limits::max(); + + if (min_value > 0) { + min_value = 0; + } + + if (max_value < 0) { + max_value = 0; + } + + // The reason for computing the inverted scale first is to make it bit-perfect with quantized packing kernels. + // If those kernels don't do it this way anymore, it makes more sense to calculate the scale directly. + const float inv_scale = max_value != min_value ? (q_max - q_min) / (max_value - min_value) : 1.0F; + const float scale = 1.0F / inv_scale; + + const float scaled_min = min_value / scale; + const float scaled_max = max_value / scale; + + const float zero_point_f = -(scaled_min + q_min) < scaled_max + q_max ? scaled_min - q_min : scaled_max - q_max; + const int32_t zero_point = round_to_nearest_even_i32(zero_point_f); + + return {scale, zero_point}; +} + +int8_t quantize_i8_fp32(float value, float scale, int32_t zero_point) { + return static_cast(std::clamp( + round_to_nearest_even_i32(value / scale) - zero_point, std::numeric_limits::min(), + std::numeric_limits::max())); +} + +namespace { + +template +std::vector dequantize_any_type( + const void* data, const void* scales, const void* zero_points, // + QuantizationMethod method, bool is_asymm, size_t height, size_t width) { + static_assert(is_floating_point); + static_assert(is_integral); + + std::vector dst; + dst.resize(height * width * sizeof(Output)); + KAI_ASSUME(size_in_bits % 8 == 0); + + auto scale = read_array(scales, 0); + KAI_UNUSED(is_asymm); + KAI_UNUSED(zero_points); + auto zero_point = is_asymm ? read_array(zero_points, 0) : // + -static_cast(numeric_lowest>); + + for (size_t y = 0; y < height; ++y) { + if (method == QuantizationMethod::PER_ROW) { + scale = read_array(scales, y); + + if (is_asymm) { + zero_point = read_array(zero_points, y); + } + } + + for (size_t x = 0; x < width; ++x) { + const ZeroPoint input = read_array(data, y * width + x); + const Scale output = Scale(input - zero_point) * scale; + write_array(dst.data(), y * width + x, output); + } + } + + return dst; +} + +} // namespace + +std::vector dequantize( + const void* data, const void* scales, const void* zero_points, // + DataType src_dt, DataType dst_dt, QuantizationMethod method, // + size_t height, size_t width) { + switch (src_dt) { + case DataType::QSU4: + switch (dst_dt) { + case DataType::FP32: + return dequantize_any_type( + data, scales, zero_points, method, false, height, width); + + default: + KAI_ERROR("Unsupported destination data type!"); + } + + default: + KAI_ERROR("Unsupported source data type!"); + } +} + +} // namespace kai::test diff --git a/test/reference/quantize.hpp b/test/reference/quantize.hpp new file mode 100644 index 00000000..904ac5fc --- /dev/null +++ b/test/reference/quantize.hpp @@ -0,0 +1,60 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "test/common/data_type.hpp" + +namespace kai::test { + +/// Quantization method. +enum class QuantizationMethod : uint32_t { + PER_MATRIX, ///< Per-matrix, i.e. one quantization scale and zero point for the entire matrix. + PER_ROW, ///< Per-row, i.e. one quantization scale and zero point for each row. +}; + +/// Calculates the quantization information for 8-bit signed asymmetric type from the value range. +/// +/// @param[in] min_value Minimum value. +/// @param[in] max_value Maximum value. +/// +/// @return The scale and zero point. +std::tuple get_qi8_scale_zero_point_from_range(float min_value, float max_value); + +/// Quantizes the single-precision floating-point value using 8-bit asymmetric quantization. +/// +/// Formula: `q = f / scale + zero_point` where `q` is quantized value and `f` is floating-point value. +/// +/// @param[in] value Value to be quantized. +/// @param[in] scale Scale. +/// @param[in] zero_point Zero point. +/// +/// @return The quantized value. +int8_t quantize_i8_fp32(float value, float scale, int32_t zero_point); + +/// Dequantizes the matrix to floating-point. +/// +/// @param[in] data Quantized data buffer. +/// @param[in] scales Quantization scales. +/// @param[in] zero_points (Optional) Quantization zero points. +/// @param[in] src_dt Quantized data type. +/// @param[in] dst_dt Dequantized data type. +/// @param[in] method Quantization method. +/// @param[in] height Number of rows. +/// @param[in] width Number of columns. +/// +/// @return The dequantized data buffer. +std::vector dequantize( + const void* data, const void* scales, const void* zero_points, // + DataType src_dt, DataType dst_dt, QuantizationMethod method, // + size_t height, size_t width); + +} // namespace kai::test diff --git a/test/reference/reduce.cpp b/test/reference/reduce.cpp new file mode 100644 index 00000000..dff4cf69 --- /dev/null +++ b/test/reference/reduce.cpp @@ -0,0 +1,113 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/reduce.hpp" + +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/common/int4.hpp" +#include "test/common/memory.hpp" + +namespace kai::test { + +namespace { + +template +T scalar_reduce(T curr_value, T new_value) { + if constexpr (op == ReductionOperator::ADD) { + return curr_value + new_value; + } +} + +template +std::vector reduce_any_op_type(const void* src, size_t height, size_t width, size_t dimension) { + std::vector dst; + + switch (dimension) { + case 0: + dst.resize(height * size_in_bits / 8); + KAI_ASSUME(height * size_in_bits % 8 == 0); + + for (size_t y = 0; y < height; ++y) { + Output acc = read_array(src, y * width); + + for (size_t x = 1; x < width; ++x) { + Output value = read_array(src, y * width + x); + acc = scalar_reduce(acc, value); + } + + write_array(dst.data(), y, acc); + } + + break; + + case 1: + dst.resize(width * size_in_bits / 8); + KAI_ASSUME(width * size_in_bits % 8 == 0); + + for (size_t x = 0; x < width; ++x) { + Output acc = read_array(src, x); + + for (size_t y = 1; y < height; ++y) { + Output value = read_array(src, y * width + x); + acc = scalar_reduce(acc, value); + } + + write_array(dst.data(), x, acc); + } + + break; + + default: + KAI_ERROR("Only 2D data is supported!"); + } + + return dst; +} + +template +std::vector reduce_any_op( + const void* src, const DataFormat& src_format, size_t height, size_t width, const DataFormat& dst_format, + size_t dimension) { + KAI_ASSUME(src_format.is_raw()); + KAI_ASSUME(dst_format.is_raw()); + KAI_ASSUME(dimension < 2); + KAI_ASSUME(height > 0); + KAI_ASSUME(width > 0); + + const auto src_dt = src_format.data_type(); + const auto dst_dt = dst_format.data_type(); + + switch (src_dt) { + case DataType::QSU4: + switch (dst_dt) { + case DataType::I32: + return reduce_any_op_type(src, height, width, dimension); + break; + + default: + KAI_ERROR("Unsupported data type!"); + } + + default: + KAI_ERROR("Unsupported data type!"); + } +} + +} // namespace + +std::vector reduce_add( + const void* src, const DataFormat& src_format, size_t height, size_t width, const DataFormat& dst_format, + size_t dimension) { + return reduce_any_op(src, src_format, height, width, dst_format, dimension); +} + +} // namespace kai::test diff --git a/test/reference/reduce.hpp b/test/reference/reduce.hpp new file mode 100644 index 00000000..f6ba197a --- /dev/null +++ b/test/reference/reduce.hpp @@ -0,0 +1,36 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace kai::test { + +class DataFormat; + +/// Reduction operator. +enum class ReductionOperator : uint32_t { + ADD, ///< Addition. +}; + +/// Reduces the matrix value using addition. +/// +/// @param[in] src Input data. +/// @param[in] src_format Input data format. +/// @param[in] height Number of rows. +/// @param[in] width Number of columns. +/// @param[in] dst_foramt Output data format. +/// @param[in] dimension Reduction dimension. +/// +/// @return The reduced matrix. +std::vector reduce_add( + const void* src, const DataFormat& src_format, size_t height, size_t width, const DataFormat& dst_format, + size_t dimension); + +} // namespace kai::test -- GitLab From cee2862b84d34804969e4310b82f7670d1f223f2 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Tue, 30 Apr 2024 13:57:06 +0100 Subject: [PATCH 04/11] Add matrix multiplication test * Add reference implementation for packing LHS, RHS, and the main matmul routine. * Add test for matrix multiplication: - Generic test is implemented for all matmul methods. - The test checks the arbitrary portion of the output matrix. - The test checks the LHS packing, RHS packing and the main matmul kernels. Signed-off-by: Viet-Hoa Do --- .clang-tidy | 4 +- CMakeLists.txt | 5 + test/reference/matmul.cpp | 154 ++++++++++++++ test/reference/matmul.hpp | 60 ++++++ test/reference/pack.cpp | 206 +++++++++++++++++++ test/reference/pack.hpp | 28 +++ test/tests/matmul_test.cpp | 412 +++++++++++++++++++++++++++++++++++++ 7 files changed, 868 insertions(+), 1 deletion(-) create mode 100644 test/reference/matmul.cpp create mode 100644 test/reference/matmul.hpp create mode 100644 test/reference/pack.cpp create mode 100644 test/reference/pack.hpp create mode 100644 test/tests/matmul_test.cpp diff --git a/.clang-tidy b/.clang-tidy index 4c0dfbe8..f05ad632 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -27,6 +27,8 @@ readability-*, -readability-simplify-boolean-expr, -bugprone-easily-swappable-parameters, -cppcoreguidelines-pro-bounds-pointer-arithmetic, --performance-enum-size +-performance-enum-size, +-llvm-else-after-return, +-readability-else-after-return, ' ... diff --git a/CMakeLists.txt b/CMakeLists.txt index b7e6786d..ca7a72af 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,11 +63,16 @@ if(KLEIDIAI_BUILD_TESTS) test/common/compare.cpp test/common/matrix_portion.cpp test/common/rect.cpp + test/reference/binary_elementwise.cpp + test/reference/matmul.cpp test/reference/fill.cpp + test/reference/pack.cpp test/reference/quantize.cpp test/reference/reduce.cpp test/reference/round.cpp + + test/tests/matmul_test.cpp ) target_include_directories(kleidiai_test diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp new file mode 100644 index 00000000..6b2ab041 --- /dev/null +++ b/test/reference/matmul.cpp @@ -0,0 +1,154 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/matmul.hpp" + +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/common/int4.hpp" +#include "test/common/memory.hpp" +#include "test/common/printer.hpp" +#include "test/reference/binary_elementwise.hpp" +#include "test/reference/pack.hpp" +#include "test/reference/quantize.hpp" +#include "test/reference/reduce.hpp" + +namespace kai::test { + +namespace { + +/// Matrix multiplication. +/// +/// @tparam T Data type. +/// +/// @param[in] lhs LHS operand data buffer. +/// @param[in] rhs RHS operand data buffer. +/// @param[in] m Output height. +/// @param[in] n Output width. +/// @param[in] k Non-transposed LHS width and non-transposed RHS height. +/// @param[in] lhs_transposed `true` if LHS operand is transposed. +/// @param[in] rhs_transposed `true` if RHS operand is transposed. +/// +/// @return The result data buffer. +template +std::vector matmul_any_type( + const void* lhs, const void* rhs, // + size_t m, size_t n, size_t k, // + bool lhs_transposed, bool rhs_transposed) { + const auto lhs_m_stride = lhs_transposed ? 1 : k; + const auto lhs_k_stride = lhs_transposed ? m : 1; + + const auto rhs_n_stride = rhs_transposed ? k : 1; + const auto rhs_k_stride = rhs_transposed ? 1 : n; + + std::vector dst; + dst.resize(m * n * size_in_bits / 8); + KAI_ASSUME(n * size_in_bits % 8 == 0); + + for (size_t im = 0; im < m; ++im) { + for (size_t in = 0; in < n; ++in) { + T acc{0}; + + for (size_t ik = 0; ik < k; ++ik) { + const auto lhs_value = read_array(lhs, im * lhs_m_stride + ik * lhs_k_stride); + const auto rhs_value = read_array(rhs, in * rhs_n_stride + ik * rhs_k_stride); + acc += lhs_value * rhs_value; + } + + write_array(dst.data(), im * n + in, acc); + } + } + + return dst; +} + +} // namespace + +std::vector matmul_pack_rhs( + const void* data, const void* scales, const void* zero_points, const DataFormat& src_format, + const DataFormat& dst_format, size_t height, size_t width) { + const auto src_dt = src_format.data_type(); + const auto src_qf = src_format.quantization_format(); + + const auto dst_dt = dst_format.data_type(); + const auto dst_qf = dst_format.quantization_format(); + + std::vector tmp_data; + std::vector tmp_scales; + std::vector tmp_zero_points; + + if (src_dt == DataType::QSU4 && src_qf == DataFormat::QuantizationFormat::NONE && // + dst_dt == DataType::QSI4 && dst_qf == DataFormat::QuantizationFormat::PER_ROW) { + // For this specific RHS format conversion: + // + // * 4-bit data is added by 8. + // * Scale is divided by 16. + // * Zero point is accumulation of all values in the same row. + + KAI_ASSUME(zero_points == nullptr); + const int32_t zero_point = 8; + const uint8_t zero_point_i4 = UInt4::pack_u8(UInt4(zero_point), UInt4(zero_point)); + const int32_t row_zero_point = zero_point * static_cast(width); + + KAI_ASSUME(dst_format.subblock_width() > 0); + const auto subblock_width_i32 = static_cast(dst_format.subblock_width()); + const auto subblock_width_f = static_cast(dst_format.subblock_width()); + + tmp_zero_points = reduce_add(data, src_format, height, width, DataFormat(DataType::I32), 0); + tmp_zero_points = sub(tmp_zero_points.data(), DataType::I32, height, 1, &row_zero_point, DataType::I32, 1, 1); + tmp_zero_points = + mul(tmp_zero_points.data(), DataType::I32, height, 1, &subblock_width_i32, DataType::I32, 1, 1); + zero_points = tmp_zero_points.data(); + + tmp_data = add(data, DataType::QSU4, height, width, &zero_point_i4, DataType::QSU4, 1, 1); + data = tmp_data.data(); + + tmp_scales = div(scales, DataType::FP32, height, 1, &subblock_width_f, DataType::FP32, 1, 1); + scales = tmp_scales.data(); + } + + return pack(dst_format, data, scales, zero_points, src_format, height, width); +} + +std::vector matmul( + const void* lhs, const void* lhs_scales, const void* lhs_zero_points, DataType lhs_dt, // + const void* rhs, const void* rhs_scales, const void* rhs_zero_points, DataType rhs_dt, // + DataType dst_dt, // + size_t m, size_t n, size_t k, // + bool lhs_transposed, bool rhs_transposed) { + const auto lhs_h = lhs_transposed ? k : m; + const auto lhs_w = lhs_transposed ? m : k; + + const auto rhs_h = rhs_transposed ? n : k; + const auto rhs_w = rhs_transposed ? k : n; + + std::vector tmp_lhs; + std::vector tmp_rhs; + + if (data_type_is_quantized(lhs_dt)) { + tmp_lhs = dequantize( + lhs, lhs_scales, lhs_zero_points, lhs_dt, DataType::FP32, QuantizationMethod::PER_MATRIX, lhs_h, lhs_w); + lhs = tmp_lhs.data(); + } + + if (data_type_is_quantized(rhs_dt)) { + tmp_rhs = dequantize( + rhs, rhs_scales, rhs_zero_points, rhs_dt, DataType::FP32, QuantizationMethod::PER_ROW, rhs_h, rhs_w); + rhs = tmp_rhs.data(); + } + + KAI_ASSUME(dst_dt == DataType::FP32); + const auto tmp_dst = matmul_any_type(lhs, rhs, m, n, k, lhs_transposed, rhs_transposed); + + return tmp_dst; +} + +} // namespace kai::test diff --git a/test/reference/matmul.hpp b/test/reference/matmul.hpp new file mode 100644 index 00000000..7dfca92c --- /dev/null +++ b/test/reference/matmul.hpp @@ -0,0 +1,60 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "test/common/data_type.hpp" + +namespace kai::test { + +class DataFormat; + +/// Packs the RHS operand of matrix multiplication. +/// +/// @param[in] data Data buffer. +/// @param[in] scales (Optional) Quantization scales. +/// @param[in] zero_points (Optional) Quantization zero points. +/// @param[in] src_format Data format of the RHS matrix. +/// @param[in] dst_format Data format of the packed RHS matrix. +/// @param[in] height Number of rows. +/// @param[in] width Number of columns. +/// +/// @return The packed RHS matrix. +std::vector matmul_pack_rhs( + const void* data, const void* scales, const void* zero_points, const DataFormat& src_format, + const DataFormat& dst_format, size_t height, size_t width); + +/// Matrix multiplication. +/// +/// @param[in] lhs LHS operand data. +/// @param[in] lhs_scales (Optional) LHS operand quantization scales. +/// @param[in] lhs_zero_points (Optional) LHS operand quantization zero point. +/// @param[in] lhs_dt LHS operand data type. +/// @param[in] dst LHS operand data. +/// @param[in] dst_scales (Optional) LHS operand quantization scales. +/// @param[in] dst_zero_points (Optional) LHS operand quantization zero point. +/// @param[in] dst_dt LHS operand data type. +/// @param[in] dst_dt Output data type. +/// @param[in] m Output height. +/// @param[in] n Output width. +/// @param[in] k Non-transposed LHS width and non-transposed RHS height. +/// @param[in] lhs_transposed `true` if LHS operand is transposed. +/// @param[in] rhs_transposed `true` if RHS operand is transposed. +/// +/// @return The result data buffer. +std::vector matmul( + const void* lhs, const void* lhs_scales, const void* lhs_zero_points, DataType lhs_dt, // + const void* rhs, const void* rhs_scales, const void* rhs_zero_points, DataType rhs_dt, // + DataType dst_dt, // + size_t m, size_t n, size_t k, // + bool lhs_transposed, bool rhs_transposed); + +} // namespace kai::test diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp new file mode 100644 index 00000000..11bc583b --- /dev/null +++ b/test/reference/pack.cpp @@ -0,0 +1,206 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/pack.hpp" + +#include +#include +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/reference/quantize.hpp" + +namespace kai::test { + +namespace { + +/// Packs the matrix from raw to quantized format. +template +std::vector pack_quant_per_row( + const void* src, size_t height, size_t width, size_t block_height, size_t block_width) { + const auto num_groups = (height + block_height - 1) / block_height; + const auto group_num_blocks = (width + block_width - 1) / block_width; + + const auto group_zero_points_bytes = block_height * sizeof(ZeroPoint); + const auto group_scales_bytes = block_height * sizeof(Scale); + const auto block_data_bytes = block_height * block_width * sizeof(Output); + const auto group_bytes = group_zero_points_bytes + group_num_blocks * block_data_bytes + group_scales_bytes; + const auto dst_bytes = num_groups * group_bytes; + + std::vector dst; + dst.resize(dst_bytes); + + const auto* src_ptr = reinterpret_cast(src); + auto* dst_ptr = dst.data(); + + std::vector scales; + scales.resize(block_height); + + std::vector zero_points; + zero_points.resize(block_height); + + for (size_t group_no = 0; group_no < num_groups; ++group_no) { + // Finds the range of values and calculates the quantization information. + for (size_t y = 0; y < block_height; ++y) { + auto min_value = std::numeric_limits::max(); + auto max_value = std::numeric_limits::lowest(); + + for (size_t x = 0; x < width; ++x) { + const auto value = src_ptr[(group_no * block_height + y) * width + x]; + + if (value < min_value) { + min_value = value; + } + + if (value > max_value) { + max_value = value; + } + } + + std::tie(scales[y], zero_points[y]) = get_qi8_scale_zero_point_from_range(min_value, max_value); + } + + // Packs the zero points. + memcpy(dst_ptr, zero_points.data(), group_zero_points_bytes); + dst_ptr += group_zero_points_bytes; + + // Quantizes and packs the data. + for (size_t x_block = 0; x_block < group_num_blocks; ++x_block) { + for (size_t block_y = 0; block_y < block_height; ++block_y) { + for (size_t block_x = 0; block_x < block_width; ++block_x) { + const auto value = + src_ptr[(group_no * block_height + block_y) * width + x_block * block_width + block_x]; + const auto qvalue = quantize_i8_fp32(value, scales[block_y], zero_points[block_y]); + *reinterpret_cast(dst_ptr) = qvalue; + ++dst_ptr; + } + } + } + + // Packs the scales. + memcpy(dst_ptr, scales.data(), group_scales_bytes); + dst_ptr += group_scales_bytes; + } + + KAI_ASSERT(reinterpret_cast(dst_ptr) - reinterpret_cast(dst.data()) == dst_bytes); + + return dst; +} + +/// Packs the matrix with per-row quantized format. +/// +/// The source matrix is per-row quantized with separate quantization scale and zero-points data buffer. +/// The destination data is per-row quantized with blocking and embedded quantization information. +std::vector pack_per_row_qs4( + const void* src, const void* scales, const void* zero_points, size_t height, size_t width, size_t block_height, + size_t block_width, size_t subblock_height, size_t subblock_width) { + // Number of elements in a sub-block in vertical and horizontal axes. + const auto num_element_rows = subblock_height; + const auto num_element_cols = subblock_width; + const auto src_element_row_stride = width / 2; + + // Number of sub-blocks in a block in vertical and horizontal axes. + const auto num_subblock_rows = block_height / subblock_height; + const auto num_subblock_cols = block_width / subblock_width; + const auto src_subblock_col_stride = subblock_width / 4; + const auto src_subblock_row_stride = subblock_height * width / 2; + + // Number of blocks in the matrix in vertical and horizontal axes. + const auto num_block_rows = (height + block_height - 1) / block_height; + const auto num_block_cols = (width + block_width - 1) / block_width; + const auto src_block_col_stride = block_width / 2; + const auto src_block_row_stride = block_height * width / 2; + + const auto dst_block_row_scales_bytes = block_height * sizeof(float); + const auto dst_block_row_zero_points_bytes = block_height * sizeof(int32_t); + const auto dst_block_row_data_bytes = num_block_cols * block_height * block_width / 2; + const auto dst_bytes = + num_block_rows * (dst_block_row_zero_points_bytes + dst_block_row_data_bytes + dst_block_row_scales_bytes); + + std::vector dst; + dst.resize(dst_bytes); + + const auto* src_ptr = reinterpret_cast(src); + const auto* scales_ptr = reinterpret_cast(scales); + const auto* zero_points_ptr = reinterpret_cast(zero_points); + auto* dst_ptr = dst.data(); + + for (size_t block_row = 0; block_row < num_block_rows; ++block_row) { + if (zero_points_ptr != nullptr) { + memcpy(dst_ptr, zero_points_ptr + block_row * block_height, dst_block_row_zero_points_bytes); + } + + dst_ptr += dst_block_row_zero_points_bytes; + + for (size_t block_col = 0; block_col < num_block_cols; ++block_col) { + for (size_t subblock_col = 0; subblock_col < num_subblock_cols; ++subblock_col) { + for (size_t subblock_row = 0; subblock_row < num_subblock_rows; ++subblock_row) { + for (size_t element_col = 0; element_col < num_element_cols / 4; ++element_col) { + for (size_t element_row = 0; element_row < num_element_rows; ++element_row) { + const auto byte_lo = src_ptr[ // + block_row * src_block_row_stride + block_col * src_block_col_stride + + subblock_row * src_subblock_row_stride + subblock_col * src_subblock_col_stride + + element_row * src_element_row_stride + element_col]; + const auto byte_hi = src_ptr[ // + block_row * src_block_row_stride + block_col * src_block_col_stride + + subblock_row * src_subblock_row_stride + subblock_col * src_subblock_col_stride + + element_row * src_element_row_stride + element_col + block_width / 4]; + + const auto packed_byte0 = (byte_lo & 0x0F) | (byte_hi << 4); + const auto packed_byte1 = (byte_lo >> 4) | (byte_hi & 0xF0); + + dst_ptr[0] = packed_byte0; // ^ 0x88; + dst_ptr[1] = packed_byte1; // ^ 0x88; + dst_ptr += 2; + } + } + } + } + } + + if (scales_ptr != nullptr) { + memcpy(dst_ptr, scales_ptr + block_row * block_height, dst_block_row_scales_bytes); + } + dst_ptr += dst_block_row_scales_bytes; + } + + KAI_ASSERT(reinterpret_cast(dst_ptr) - reinterpret_cast(dst.data()) == dst_bytes); + + return dst; +} + +} // namespace + +std::vector pack( + const DataFormat& dst_format, const void* src, const void* scales, const void* zero_points, + const DataFormat& src_format, size_t height, size_t width) { + const auto dst_dt = dst_format.data_type(); + const auto dst_qf = dst_format.quantization_format(); + const auto src_dt = src_format.data_type(); + const auto src_qf = src_format.quantization_format(); + + if (src_qf == DataFormat::QuantizationFormat::NONE && dst_qf == DataFormat::QuantizationFormat::PER_ROW) { + if (dst_dt == DataType::QI8 && src_dt == DataType::FP32 && dst_format.scale_data_type() == DataType::FP32 && + dst_format.zero_point_data_type() == DataType::I32) { + return pack_quant_per_row( + src, height, width, dst_format.block_height(), dst_format.block_width()); + } else if ( + dst_dt == DataType::QSI4 && src_dt == DataType::QSU4 && dst_format.scale_data_type() == DataType::FP32 && + dst_format.zero_point_data_type() == DataType::I32) { + return pack_per_row_qs4( + src, scales, zero_points, height, width, dst_format.block_height(), dst_format.block_width(), + dst_format.subblock_height(), dst_format.subblock_width()); + } + } + + KAI_ERROR("Unsupported operation!"); +} + +} // namespace kai::test diff --git a/test/reference/pack.hpp b/test/reference/pack.hpp new file mode 100644 index 00000000..43362009 --- /dev/null +++ b/test/reference/pack.hpp @@ -0,0 +1,28 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace kai::test { + +class DataFormat; + +/// Packs the matrix. +/// +/// @param[in] dst_format Data format of the destination matrix. +/// @param[in] src Data buffer of the source matrix. +/// @param[in] src_format Data format of the source matrix. +/// @param[in] height Number of rows of the source matrix. +/// @param[in] width Number of columns of the source matrix. +std::vector pack( + const DataFormat& dst_format, const void* src, const void* scales, const void* zero_points, + const DataFormat& src_format, size_t height, size_t width); + +} // namespace kai::test diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp new file mode 100644 index 00000000..b53e2a85 --- /dev/null +++ b/test/tests/matmul_test.cpp @@ -0,0 +1,412 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/matmul.hpp" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/compare.hpp" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/common/matrix_portion.hpp" +#include "test/common/printer.hpp" +#include "test/reference/fill.hpp" +#include "test/reference/pack.hpp" + +namespace kai::test { + +// NOLINTBEGIN(misc-non-private-member-variables-in-classes) + +/// Matrix multiplication method. +struct MatMulMethod { + size_t m0; ///< Block size in M dimension. + size_t n0; ///< Block size in N dimension. + size_t k0; ///< Block size in K dimension. + + bool lhs_transposed; ///< LHS matrix is transposed. + bool rhs_transposed; ///< RHS matrix is transposed. + + DataFormat dst_format; ///< Data format of the destination matrix. + DataFormat lhs_format; ///< Data format of the LHS matrix. + DataFormat packed_lhs_format; ///< Data format of the packed LHS matrix. + DataFormat rhs_format; ///< Data format of the RHS matrix. + DataFormat packed_rhs_format; ///< Data for mat of the packed RHS matrix. + + /// Gets the offset in bytes of the LHS matrix. + /// + /// @param[in] m_idx Coordinate of the matrix in M dimension. + /// @param[in] stride Row stride in bytes. + /// + /// @return The offset in bytes. + std::function fn_get_lhs_offset; + + /// Gets the size in bytes of the packed LHS matrix. + /// + /// @param[in] m Size of the matrix in M dimension. + /// @param[in] k Size of the matrix in K dimension. + /// + /// @return The size in bytes. + std::function fn_get_packed_lhs_size; + + /// Gets the offset in bytes of the packed LHS matrix. + /// + /// @param[in] m_idx Coordinate of the matrix in M dimension. + /// @param[in] k Size of the matrix in K dimension. + /// + /// @return The offset in bytes. + std::function fn_get_packed_lhs_offset; + + /// Preprocesses the LHS matrix. + /// + /// @param[in] m Size of the matrix in M dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] lhs LHS matrix data buffer. + /// @param[in] lhs_row_stride Row stride in bytes of the LHS matrix. + /// @param[out] packed_lhs Packed LHS matrix data buffer. + std::function fn_pack_lhs; + + /// Gets the offset in bytes of the RHS matrix. + /// + /// @param[in] n_idx Coordinate of the matrix in N dimension. + /// @param[in] stride Row stride in bytes. + /// + /// @return The offset in bytes. + std::function fn_get_rhs_offset; + + /// Gets the size in bytes of the packed RHS matrix. + /// + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] block_height Block height. + /// @param[in] block_width Block width. + /// + /// @return The size in bytes. + std::function fn_get_packed_rhs_size; + + /// Gets the offset in bytes of the packed RHS matrix. + /// + /// @param[in] n_idx Coordinate of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] block_height Block height. + /// @param[in] block_width Block width. + /// + /// @return The offset in bytes. + std::function fn_get_packed_rhs_offset; + + /// Performs matrix multiplication. + /// + /// @param[in] m Size of the matrix in M dimension. + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] lhs_p Packed LHS data buffer. + /// @param[in] rhs_p Packed RHS data buffer. + /// @param[out] dst Output data buffer. + /// @param[in] dst_stride_row Output row stride. + /// @param[in] dst_stride_col Output column stride. + /// @param[in] scalar_min Lower bound of the output data. + /// @param[in] scalar_max Upper bound of the output data. + std::function + fn_main; + + /// Gets a value indicating whether pre-processing the RHS matrix is needed. + [[nodiscard]] bool is_pack_rhs_needed() const { + return false; + } + + /// Preprocesses the RHS matrix. + /// + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] rhs RHS data buffer. + /// @param[in] rhs_row_stride RHS row stride. + /// @param[in] bias Bias data buffer. + /// @param[in] scale Quantization scales data buffer. + /// @param[out] packed_rhs Packed RHS data buffer. + void pack_rhs( + size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale, + void* packed_rhs) const { + KAI_UNUSED(n); + KAI_UNUSED(k); + KAI_UNUSED(rhs); + KAI_UNUSED(rhs_row_stride); + KAI_UNUSED(bias); + KAI_UNUSED(scale); + KAI_UNUSED(packed_rhs); + + KAI_ERROR("RHS pre-processing is not supported!"); + } +}; + +// NOLINTEND(misc-non-private-member-variables-in-classes) + +/// List of supported matrix multiplication methods. +static const std::array matmul_methods = {}; + +/// Matrix multiplication shape. +struct MatMulShape { + size_t m; ///< LHS height. + size_t n; ///< RHS width. + size_t k; ///< LHS width and RHS height. +}; + +/// Matrix multiplication test information. +using MatMulTestParams = std::tuple; + +/// Prints the test information. +void PrintTo(const MatMulTestParams& param, std::ostream* os) { + const auto& [shape, method_no, portion] = param; + + *os << "m: " << shape.m << ", n: " << shape.n << ", k: " << shape.k << ", method_no: " << method_no + << ", portion: { start_row: " << portion.start_row() << ", start_col: " << portion.start_col() + << ", height: " << portion.height() << ", width: " << portion.width() << "}"; +} + +/// Matrix multiplication test fixture. +class MatMulTest : public testing::TestWithParam { +private: + /// Unique ID: m, n, k, method_id. + using TestDataId = std::tuple; + +protected: + /// Cached test data that is shared between multiple test case. + struct TestData { + std::vector lhs; ///< LHS operand. + std::vector ref_packed_lhs; ///< Reference packed LHS. + std::vector rhs; ///< RHS operand. + std::vector rhs_scales; ///< RHS per-row quantization scales. + std::vector ref_packed_rhs; ///< Reference packed RHS. + std::vector ref_dst; ///< Reference output. + }; + + /// Gets the test data for the current test case. + static const TestData& test_data() { + const auto& [info, method_no, portion] = GetParam(); + const TestDataId data_id{info.m, info.n, info.k, method_no}; + + // If the test data is already available, returns it. + const auto data_it = _data.find(data_id); + + if (data_it != _data.end()) { + return data_it->second; + } + + // Generates the test data. + const auto& method = matmul_methods.at(method_no); + + const auto lhs_h = method.lhs_transposed ? info.k : info.m; + const auto lhs_w = method.lhs_transposed ? info.m : info.k; + auto lhs = fill_matrix_random(lhs_h, lhs_w, method.lhs_format, 0); + auto ref_packed_lhs = + pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w); + + const auto rhs_h = method.rhs_transposed ? info.n : info.k; + const auto rhs_w = method.rhs_transposed ? info.k : info.n; + auto rhs = fill_matrix_random(rhs_h, rhs_w, method.rhs_format, 1); + + std::vector rhs_scales; + if (data_type_is_quantized(method.rhs_format.data_type()) && + method.rhs_format.quantization_format() == DataFormat::QuantizationFormat::NONE) { + rhs_scales = fill_matrix_random(rhs_h, 1, DataFormat(DataType::FP32), 2); + } + + auto packed_rhs = matmul_pack_rhs( + rhs.data(), !rhs_scales.empty() ? rhs_scales.data() : nullptr, nullptr, method.rhs_format, + method.packed_rhs_format, rhs_h, rhs_w); + + KAI_ASSUME(method.lhs_format.is_raw()); + KAI_ASSUME(method.rhs_format.is_raw()); + KAI_ASSUME(method.dst_format.is_raw()); + auto ref_dst = matmul( + lhs.data(), nullptr, nullptr, method.lhs_format.data_type(), // + rhs.data(), rhs_scales.data(), nullptr, method.rhs_format.data_type(), // + method.dst_format.data_type(), // + info.m, info.n, info.k, method.lhs_transposed, method.rhs_transposed); + + const auto& data = _data[data_id] = { + .lhs = std::move(lhs), + .ref_packed_lhs = std::move(ref_packed_lhs), + .rhs = std::move(rhs), + .rhs_scales = std::move(rhs_scales), + .ref_packed_rhs = std::move(packed_rhs), + .ref_dst = std::move(ref_dst), + }; + + return data; + } + +private: + static std::map _data; +}; + +std::map MatMulTest::_data; + +/// Tests the LHS packing kernel. +TEST_P(MatMulTest, PackedLhs) { + const auto& [info, method_no, portion] = GetParam(); + const auto& data = test_data(); + const auto& method = matmul_methods.at(method_no); + + if (method.fn_pack_lhs == nullptr) { + GTEST_SKIP(); + } + + const auto lhs_h = method.lhs_transposed ? info.k : info.m; + const auto lhs_w = method.lhs_transposed ? info.m : info.k; + + const auto rect = portion.compute_portion( + lhs_h, lhs_w, method.packed_lhs_format.scheduler_block_height(lhs_h), + method.packed_lhs_format.scheduler_block_width(lhs_w)); + + if (rect.height() == 0 || rect.width() == 0) { + GTEST_SKIP(); + } + + const auto ref_lhs_row_stride = method.lhs_format.default_row_stride(lhs_w); + + const auto packed_lhs_size = method.fn_get_packed_lhs_size(info.m, info.k); + const auto ref_packed_lhs_size = method.packed_lhs_format.default_size_in_bytes(lhs_h, lhs_w); + ASSERT_EQ(packed_lhs_size, ref_packed_lhs_size); + + const auto lhs_offset = method.fn_get_lhs_offset(rect.start_row(), ref_lhs_row_stride); + const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(rect.start_row(), 0, lhs_w); + ASSERT_EQ(lhs_offset, ref_lhs_offset); + + const auto packed_lhs_offset = method.fn_get_packed_lhs_offset(rect.start_row(), info.k); + const auto ref_packed_lhs_offset = method.packed_lhs_format.default_offset_in_bytes(rect.start_row(), 0, lhs_w); + ASSERT_EQ(packed_lhs_offset, ref_packed_lhs_offset); + + std::vector packed_lhs; + packed_lhs.resize(packed_lhs_size); + method.fn_pack_lhs( + rect.height(), rect.width(), data.lhs.data() + lhs_offset, ref_lhs_row_stride, + packed_lhs.data() + packed_lhs_offset); + + DefaultMismatchHandler handler(0, 0.0001, 0, 0.001); + const auto success = + compare(packed_lhs.data(), data.ref_packed_lhs.data(), method.packed_lhs_format, lhs_h, lhs_w, rect, handler); + ASSERT_TRUE(success); +} + +/// Tests the RHS packing kernel. +TEST_P(MatMulTest, PackedRhs) { + const auto& [info, method_no, portion] = GetParam(); + const auto& data = test_data(); + const auto& method = matmul_methods.at(method_no); + + if (!method.is_pack_rhs_needed()) { + GTEST_SKIP(); + } + + const auto rhs_h = method.rhs_transposed ? info.n : info.k; + const auto rhs_w = method.rhs_transposed ? info.k : info.n; + + const auto rect = portion.compute_portion( + rhs_h, rhs_w, method.packed_rhs_format.scheduler_block_height(rhs_h), + method.packed_rhs_format.scheduler_block_width(rhs_w)); + + if (rect.height() == 0 || rect.width() == 0) { + GTEST_SKIP(); + } + + const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(rhs_w); + + const auto rhs_offset = method.fn_get_rhs_offset(rect.start_row(), ref_rhs_row_stride); + const auto ref_rhs_offset = method.rhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), rhs_w); + ASSERT_EQ(rhs_offset, ref_rhs_offset); + + const auto packed_rhs_size = method.fn_get_packed_rhs_size( + rhs_h, rhs_w, method.packed_rhs_format.block_height(), method.packed_rhs_format.block_width()); + const auto ref_packed_rhs_size = method.packed_rhs_format.default_size_in_bytes(rhs_h, rhs_w); + ASSERT_EQ(packed_rhs_size, ref_packed_rhs_size); + + const auto packed_rhs_offset = method.fn_get_packed_rhs_offset( + rect.start_row(), rhs_w, method.packed_rhs_format.block_height(), method.packed_rhs_format.block_width()); + const auto ref_packed_rhs_offset = + method.packed_rhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), rhs_w); + ASSERT_EQ(packed_rhs_offset, ref_packed_rhs_offset); + + const auto ref_rhs_scales_offset = + rect.start_row() * data_type_size_in_bits(method.packed_rhs_format.scale_data_type()) / 8; + + std::vector packed_rhs; + packed_rhs.resize(packed_rhs_size); + + method.pack_rhs( + rect.height(), rect.width(), data.rhs.data() + rhs_offset, ref_rhs_row_stride, nullptr, + !data.rhs_scales.empty() ? data.rhs_scales.data() + ref_rhs_scales_offset : nullptr, + packed_rhs.data() + packed_rhs_offset); + + DefaultMismatchHandler handler(0, 0.0001, 0, 0.001); + const auto success = + compare(packed_rhs.data(), data.ref_packed_rhs.data(), method.packed_rhs_format, rhs_h, rhs_w, rect, handler); + ASSERT_TRUE(success); +} + +/// Tests the output. +TEST_P(MatMulTest, Output) { + const auto& [info, method_no, portion] = GetParam(); + const auto& data = test_data(); + const auto& method = matmul_methods.at(method_no); + + const auto rect = portion.compute_portion(info.m, info.n, method.m0, method.n0); + + if (rect.height() == 0 || rect.width() == 0) { + GTEST_SKIP(); + } + + const auto ref_dst_row_stride = method.dst_format.default_row_stride(info.n); + const auto ref_dst_col_stride = data_type_size_in_bits(method.dst_format.data_type()) / 8; + + const auto ref_packed_lhs_offset = method.packed_lhs_format.default_offset_in_bytes( + method.lhs_transposed ? 0 : rect.start_row(), method.lhs_transposed ? rect.start_row() : 0, + method.lhs_transposed ? info.m : info.k); + const auto ref_packed_rhs_offset = method.packed_rhs_format.default_offset_in_bytes( + method.rhs_transposed ? rect.start_col() : 0, method.rhs_transposed ? 0 : rect.start_col(), + method.rhs_transposed ? info.k : info.n); + const auto ref_dst_offset = method.dst_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), info.n); + + std::vector dst; + dst.resize(method.dst_format.default_size_in_bytes(info.m, info.n)); + + method.fn_main( + rect.height(), rect.width(), info.k, data.ref_packed_lhs.data() + ref_packed_lhs_offset, + data.ref_packed_rhs.data() + ref_packed_rhs_offset, reinterpret_cast(dst.data() + ref_dst_offset), + ref_dst_row_stride, ref_dst_col_stride, std::numeric_limits::lowest(), + std::numeric_limits::max()); + + DefaultMismatchHandler handler(0, 0.1, 0, 0.05); + const auto success = compare(dst.data(), data.ref_dst.data(), method.dst_format, info.m, info.n, rect, handler); + ASSERT_TRUE(success); +} + +INSTANTIATE_TEST_SUITE_P( + MatMul, MatMulTest, + testing::Combine( + testing::Values( + MatMulShape{4, 4, 32}, // + MatMulShape{12, 16, 64}), + testing::Range(0, matmul_methods.size()), + testing::Values( + MatrixPortion(0, 0, 1, 1), // Full matrix. + MatrixPortion(0, 0, 0.25, 0.25), // Top-left corner. + MatrixPortion(0.75, 0.75, 1, 1) // Bottom-right corner. + ))); + +} // namespace kai::test -- GitLab From 2ec5fc6a71a3ba517c4bebe403e56fe03f45e86f Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Tue, 30 Apr 2024 16:50:45 +0100 Subject: [PATCH 05/11] Fix GCC compilation error Signed-off-by: Viet-Hoa Do --- test/common/compare.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/common/compare.cpp b/test/common/compare.cpp index 206ec0e5..72eba283 100644 --- a/test/common/compare.cpp +++ b/test/common/compare.cpp @@ -186,6 +186,8 @@ bool compare( break; } + break; + case DataFormat::QuantizationFormat::PER_ROW: if (data_type == DataType::QI8 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { return compare_per_row( -- GitLab From 5be25ff33fb098b03abc3ef391c98f12b24cadcc Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 2 May 2024 10:11:11 +0100 Subject: [PATCH 06/11] Change data type QI8 to QAI8 Signed-off-by: Viet-Hoa Do --- test/common/data_type.hpp | 2 +- test/common/printer.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/common/data_type.hpp b/test/common/data_type.hpp index 8fbdcf9d..c707b53f 100644 --- a/test/common/data_type.hpp +++ b/test/common/data_type.hpp @@ -41,7 +41,7 @@ enum class DataType : uint16_t { I32 = 0b1'1'0'0'0000'00100000, ///< 32-bit signed integer. - QI8 = 0b1'1'1'1'0000'00001000, ///< 8-bit asymmetric quantized. + QAI8 = 0b1'1'1'1'0000'00001000, ///< 8-bit unsigned asymmetric quantized. QSU4 = 0b1'0'1'0'0000'00000100, ///< 4-bit unsigned symmetric quantized. QSI4 = 0b1'1'1'0'0000'00000100, ///< 4-bit signed symmetric quantized. diff --git a/test/common/printer.cpp b/test/common/printer.cpp index 368ae7db..9bc7a100 100644 --- a/test/common/printer.cpp +++ b/test/common/printer.cpp @@ -40,7 +40,7 @@ inline void print_data(std::ostream& os, const uint8_t* data, size_t len, DataTy os << reinterpret_cast(data)[i]; break; - case DataType::QI8: + case DataType::QAI8: os << static_cast(reinterpret_cast(data)[i]); break; -- GitLab From 581eae238c55404a18546f14ec81b2e4321732c8 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 2 May 2024 10:27:08 +0100 Subject: [PATCH 07/11] Address various comments Signed-off-by: Viet-Hoa Do --- test/common/compare.cpp | 2 +- test/common/matrix_portion.cpp | 19 ++++++++++--------- test/common/matrix_portion.hpp | 10 +++++----- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/test/common/compare.cpp b/test/common/compare.cpp index 72eba283..f1dfecdf 100644 --- a/test/common/compare.cpp +++ b/test/common/compare.cpp @@ -189,7 +189,7 @@ bool compare( break; case DataFormat::QuantizationFormat::PER_ROW: - if (data_type == DataType::QI8 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { + if (data_type == DataType::QAI8 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { return compare_per_row( imp_data, ref_data, format, full_height, full_width, rect, handler); } else if (data_type == DataType::QSI4 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { diff --git a/test/common/matrix_portion.cpp b/test/common/matrix_portion.cpp index abf795df..e7c42bf1 100644 --- a/test/common/matrix_portion.cpp +++ b/test/common/matrix_portion.cpp @@ -9,6 +9,7 @@ #include #include +#include "src/kai_common.h" #include "test/common/rect.hpp" #include "test/reference/round.hpp" @@ -36,15 +37,15 @@ float MatrixPortion::width() const { Rect MatrixPortion::compute_portion( size_t full_height, size_t full_width, size_t scheduler_block_height, size_t scheduler_block_width) const { - const auto start_row_f = std::clamp(_start_row, 0, 1); - const auto start_col_f = std::clamp(_start_col, 0, 1); - const auto height_f = std::clamp(_height, 0, 1); - const auto width_f = std::clamp(_width, 0, 1); - - auto start_row = round_to_nearest_even_usize(start_row_f * static_cast(full_height)); - auto start_col = round_to_nearest_even_usize(start_col_f * static_cast(full_width)); - auto height = round_to_nearest_even_usize(height_f * static_cast(full_height)); - auto width = round_to_nearest_even_usize(width_f * static_cast(full_width)); + KAI_ASSUME(_start_row >= 0.0F && _start_row <= 1.0F); + KAI_ASSUME(_start_col >= 0.0F && _start_col <= 1.0F); + KAI_ASSUME(_height >= 0.0F && _height <= 1.0F); + KAI_ASSUME(_width >= 0.0F && _width <= 1.0F); + + auto start_row = round_to_nearest_even_usize(_start_row * static_cast(full_height)); + auto start_col = round_to_nearest_even_usize(_start_col * static_cast(full_width)); + auto height = round_to_nearest_even_usize(_height * static_cast(full_height)); + auto width = round_to_nearest_even_usize(_width * static_cast(full_width)); start_row = round_down_multiple(start_row, scheduler_block_height); start_col = round_down_multiple(start_col, scheduler_block_width); diff --git a/test/common/matrix_portion.hpp b/test/common/matrix_portion.hpp index 02f9f548..68193a40 100644 --- a/test/common/matrix_portion.hpp +++ b/test/common/matrix_portion.hpp @@ -14,7 +14,7 @@ namespace kai::test { /// Portion of a matrix. /// -/// This class is used to define the sub-matrix a test is running and checking. +/// This class is used to define the sub-matrix under test. /// /// This is the relative version of @ref Rect. class MatrixPortion { @@ -43,10 +43,10 @@ public: /// /// Requirements: /// - /// * The starting coordinate of the sub-matrix shall be aligned with the block boundary. - /// * If it is not the block at the right and/or bottom edge of the full matrix, the height and width - /// of the sub-matrix shall be rounded up to multiple of the block height and width. - /// * If it is the block at the right and/or bottom edge of the full matrix, the height and width + /// * The starting coordinate of the sub-matrix shall be aligned with the scheduling block boundary. + /// * If it is not the scheduling block at the right and/or bottom edge of the full matrix, the height and width + /// of the sub-matrix shall be rounded up to multiple of the scheduling block height and width. + /// * If it is the scheduling block at the right and/or bottom edge of the full matrix, the height and width /// of the sub-matrix shall be the rounded up to the edge of the matrix. /// /// @param[in] full_height Matrix height. -- GitLab From ba29325dfd51bc1619e0a9e20f8320e24f08deca Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 2 May 2024 10:39:18 +0100 Subject: [PATCH 08/11] Change qi8 to qai8 Signed-off-by: Viet-Hoa Do --- test/reference/quantize.cpp | 2 +- test/reference/quantize.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/reference/quantize.cpp b/test/reference/quantize.cpp index fab745e6..93d6e8ef 100644 --- a/test/reference/quantize.cpp +++ b/test/reference/quantize.cpp @@ -23,7 +23,7 @@ namespace kai::test { -std::tuple get_qi8_scale_zero_point_from_range(float min_value, float max_value) { +std::tuple get_qai8_scale_zero_point_from_range(float min_value, float max_value) { constexpr float q_min = std::numeric_limits::min(); constexpr float q_max = std::numeric_limits::max(); diff --git a/test/reference/quantize.hpp b/test/reference/quantize.hpp index 904ac5fc..b9865903 100644 --- a/test/reference/quantize.hpp +++ b/test/reference/quantize.hpp @@ -27,7 +27,7 @@ enum class QuantizationMethod : uint32_t { /// @param[in] max_value Maximum value. /// /// @return The scale and zero point. -std::tuple get_qi8_scale_zero_point_from_range(float min_value, float max_value); +std::tuple get_qai8_scale_zero_point_from_range(float min_value, float max_value); /// Quantizes the single-precision floating-point value using 8-bit asymmetric quantization. /// -- GitLab From b0e7b5c5a1cf3c75b8ff1b5e569653d29ea71404 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 2 May 2024 10:42:47 +0100 Subject: [PATCH 09/11] Fix compilation error Signed-off-by: Viet-Hoa Do --- test/reference/pack.cpp | 4 ++-- test/tests/matmul_test.cpp | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp index 11bc583b..8a596ed2 100644 --- a/test/reference/pack.cpp +++ b/test/reference/pack.cpp @@ -64,7 +64,7 @@ std::vector pack_quant_per_row( } } - std::tie(scales[y], zero_points[y]) = get_qi8_scale_zero_point_from_range(min_value, max_value); + std::tie(scales[y], zero_points[y]) = get_qai8_scale_zero_point_from_range(min_value, max_value); } // Packs the zero points. @@ -187,7 +187,7 @@ std::vector pack( const auto src_qf = src_format.quantization_format(); if (src_qf == DataFormat::QuantizationFormat::NONE && dst_qf == DataFormat::QuantizationFormat::PER_ROW) { - if (dst_dt == DataType::QI8 && src_dt == DataType::FP32 && dst_format.scale_data_type() == DataType::FP32 && + if (dst_dt == DataType::QAI8 && src_dt == DataType::FP32 && dst_format.scale_data_type() == DataType::FP32 && dst_format.zero_point_data_type() == DataType::I32) { return pack_quant_per_row( src, height, width, dst_format.block_height(), dst_format.block_width()); diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index b53e2a85..f4c3582b 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -187,12 +187,12 @@ private: protected: /// Cached test data that is shared between multiple test case. struct TestData { - std::vector lhs; ///< LHS operand. - std::vector ref_packed_lhs; ///< Reference packed LHS. - std::vector rhs; ///< RHS operand. - std::vector rhs_scales; ///< RHS per-row quantization scales. - std::vector ref_packed_rhs; ///< Reference packed RHS. - std::vector ref_dst; ///< Reference output. + std::vector lhs{}; ///< LHS operand. + std::vector ref_packed_lhs{}; ///< Reference packed LHS. + std::vector rhs{}; ///< RHS operand. + std::vector rhs_scales{}; ///< RHS per-row quantization scales. + std::vector ref_packed_rhs{}; ///< Reference packed RHS. + std::vector ref_dst{}; ///< Reference output. }; /// Gets the test data for the current test case. -- GitLab From a5dae55cc32d2935e7902c600b8e7f1a2496e142 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 2 May 2024 10:51:55 +0100 Subject: [PATCH 10/11] Add a dummy test to make the test suite passed Signed-off-by: Viet-Hoa Do --- test/tests/matmul_test.cpp | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index f4c3582b..dcafa8e0 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -157,7 +157,35 @@ struct MatMulMethod { // NOLINTEND(misc-non-private-member-variables-in-classes) /// List of supported matrix multiplication methods. -static const std::array matmul_methods = {}; +static const std::array matmul_methods = { + MatMulMethod{ + .m0 = 4, + .n0 = 4, + .k0 = 32, + + .lhs_transposed = false, + .rhs_transposed = true, + + .dst_format = DataFormat(DataType::FP32), + .lhs_format = DataFormat(DataType::FP32), + .packed_lhs_format = + DataFormat(DataType::QAI8, 4, 8, DataFormat::QuantizationFormat::PER_ROW, DataType::FP32, DataType::I32), + .rhs_format = DataFormat(DataType::QSU4), + .packed_rhs_format = DataFormat( + DataType::QSI4, 4, 32, DataFormat::QuantizationFormat::PER_ROW, DataType::FP32, DataType::I32, 1, 16), + + .fn_get_lhs_offset = nullptr, + .fn_get_packed_lhs_size = nullptr, + .fn_get_packed_lhs_offset = nullptr, + .fn_pack_lhs = nullptr, + + .fn_get_rhs_offset = nullptr, + .fn_get_packed_rhs_size = nullptr, + .fn_get_packed_rhs_offset = nullptr, + + .fn_main = nullptr, + }, +}; /// Matrix multiplication shape. struct MatMulShape { @@ -365,6 +393,10 @@ TEST_P(MatMulTest, Output) { const auto& data = test_data(); const auto& method = matmul_methods.at(method_no); + if (method.fn_main == nullptr) { + GTEST_SKIP(); + } + const auto rect = portion.compute_portion(info.m, info.n, method.m0, method.n0); if (rect.height() == 0 || rect.width() == 0) { -- GitLab From 5b9a4443b2148180e4cd1577a8072dd30a6e5e46 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Fri, 3 May 2024 10:16:46 +0100 Subject: [PATCH 11/11] Remove dummy test Signed-off-by: Viet-Hoa Do --- test/sample.cpp | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100644 test/sample.cpp diff --git a/test/sample.cpp b/test/sample.cpp deleted file mode 100644 index d8e63383..00000000 --- a/test/sample.cpp +++ /dev/null @@ -1,10 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// -#include - -TEST(SampleSuite, SampleTest) { - EXPECT_EQ(1 + 2, 3); -} -- GitLab