From 133a04448a87e74ca22d0101a3cd947fef7d918c Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Tue, 30 Apr 2024 12:42:29 +0100 Subject: [PATCH 1/5] Create test framework * Add common concepts: - Data type. - Data format. * Add custom data types: - 4-bit signed integer. - 4-bit unsigned integer. * Add generic type utilities: - Type traits. - Numeric limits. - Memory access. * Add matrix printer for arbitrary data format. * Add reference implementation of rounding functions. Signed-off-by: Viet-Hoa Do --- .clang-tidy | 5 ++ CMakeLists.txt | 34 ++++++-- src/kai_common.h | 42 ++++++++++ test/common/data_format.cpp | 144 +++++++++++++++++++++++++++++++++ test/common/data_format.hpp | 126 +++++++++++++++++++++++++++++ test/common/data_type.cpp | 81 +++++++++++++++++++ test/common/data_type.hpp | 110 +++++++++++++++++++++++++ test/common/int4.cpp | 98 ++++++++++++++++++++++ test/common/int4.hpp | 112 +++++++++++++++++++++++++ test/common/logging.hpp | 68 ++++++++++++++++ test/common/memory.hpp | 77 ++++++++++++++++++ test/common/numeric_limits.hpp | 39 +++++++++ test/common/printer.cpp | 112 +++++++++++++++++++++++++ test/common/printer.hpp | 28 +++++++ test/common/type_traits.hpp | 79 ++++++++++++++++++ test/reference/round.cpp | 36 +++++++++ test/reference/round.hpp | 62 ++++++++++++++ 17 files changed, 1246 insertions(+), 7 deletions(-) create mode 100644 src/kai_common.h create mode 100644 test/common/data_format.cpp create mode 100644 test/common/data_format.hpp create mode 100644 test/common/data_type.cpp create mode 100644 test/common/data_type.hpp create mode 100644 test/common/int4.cpp create mode 100644 test/common/int4.hpp create mode 100644 test/common/logging.hpp create mode 100644 test/common/memory.hpp create mode 100644 test/common/numeric_limits.hpp create mode 100644 test/common/printer.cpp create mode 100644 test/common/printer.hpp create mode 100644 test/common/type_traits.hpp create mode 100644 test/reference/round.cpp create mode 100644 
test/reference/round.hpp diff --git a/.clang-tidy b/.clang-tidy index 7cff2942..bebfc13a 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -22,5 +22,10 @@ readability-*, -readability-identifier-length, -readability-magic-numbers, -readability-function-cognitive-complexity, +-cppcoreguidelines-pro-type-reinterpret-cast, +-cppcoreguidelines-avoid-magic-numbers, +-readability-simplify-boolean-expr, +-bugprone-easily-swappable-parameters, +-cppcoreguidelines-pro-bounds-pointer-arithmetic ' ... diff --git a/CMakeLists.txt b/CMakeLists.txt index d45d5070..b6d9662c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,9 +31,7 @@ endif() set(KLEIDIAI_WARNING_FLAGS "-Wall" - "-Wctor-dtor-privacy" "-Wdisabled-optimization" - "-Weffc++" "-Werror" "-Wextra" "-Wformat-security" @@ -42,19 +40,41 @@ set(KLEIDIAI_WARNING_FLAGS "-Wno-ignored-attributes" "-Wno-misleading-indentation" "-Wno-overlength-strings" - "-Woverloaded-virtual" - "-Wsign-promo" "-Wstrict-overflow=2" "-Wswitch-default" ) +set(KLEIDIAI_WARNING_FLAGS_CXX + "-Wctor-dtor-privacy" + "-Weffc++" + "-Woverloaded-virtual" + "-Wsign-promo" +) + if(KLEIDIAI_BUILD_TESTS) enable_testing() include(GoogleTest) - add_executable(kleidiai_test test/sample.cpp) - target_compile_options(kleidiai_test PRIVATE ${KLEIDIAI_WARNING_FLAGS}) - target_link_libraries(kleidiai_test PRIVATE GTest::gtest_main) + add_executable(kleidiai_test + test/common/data_type.cpp + test/common/data_format.cpp + test/common/printer.cpp + test/common/int4.cpp + test/reference/round.cpp + ) + + target_include_directories(kleidiai_test + PRIVATE . + ) + + target_compile_options(kleidiai_test + PRIVATE ${KLEIDIAI_WARNING_FLAGS} + PRIVATE ${KLEIDIAI_WARNING_FLAGS_CXX} + ) + + target_link_libraries(kleidiai_test + PRIVATE GTest::gtest_main + ) # Cross-compiling is a common use case which creates a conflict if DISCOVERY_MODE is set to POST_BUILD (by default) # since the host platform does not match the target. Setting the mode to PRE_TEST avoids this conflict. 
diff --git a/src/kai_common.h b/src/kai_common.h new file mode 100644 index 00000000..b44dad5b --- /dev/null +++ b/src/kai_common.h @@ -0,0 +1,42 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +// NOLINTBEGIN(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c) +// +// * cppcoreguidelines-avoid-do-while: do-while is necessary for macros. +// * cppcoreguidelines-pro-type-vararg: use of variadic arguments in fprintf is expected. +// * cert-err33-c: checking the output of fflush and fprintf is not necessary for error reporting. +#define KAI_ERROR(msg) \ + do { \ + fflush(stdout); \ + fprintf(stderr, "%s", msg); \ + exit(EXIT_FAILURE); \ + } while (0) + +#define KAI_ASSERT_MSG(cond, msg) \ + do { \ + if (!(cond)) { \ + KAI_ERROR(msg); \ + } \ + } while (0) +// NOLINTEND(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c) + +#define KAI_ASSERT(cond) KAI_ASSERT_MSG(cond, #cond) + +#define KAI_ASSERT_IF_MSG(precond, cond, msg) KAI_ASSERT_MSG(!(precond) || (cond), msg) +#define KAI_ASSERT_IF(precond, cond) KAI_ASSERT_IF_MSG(precond, cond, #precond " |-> " #cond) + +#define KAI_ASSUME_MSG KAI_ASSERT_MSG +#define KAI_ASSUME KAI_ASSERT +#define KAI_ASSUME_IF_MSG KAI_ASSERT_IF_MSG +#define KAI_ASSUME_IF KAI_ASSERT_IF + +#define KAI_UNUSED(x) (void)(x) diff --git a/test/common/data_format.cpp b/test/common/data_format.cpp new file mode 100644 index 00000000..1673b405 --- /dev/null +++ b/test/common/data_format.cpp @@ -0,0 +1,144 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/data_format.hpp" + +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_type.hpp" +#include "test/reference/round.hpp" + +namespace kai::test { + +DataFormat::DataFormat( + DataType data_type, 
size_t block_height, size_t block_width, QuantizationFormat quant_format, DataType scale_dt, + DataType zero_point_dt, size_t subblock_height, size_t subblock_width) noexcept : + _data_type(data_type), + _quant_format(quant_format), + _scale_dt(scale_dt), + _zero_point_dt(zero_point_dt), + _block_height(block_height), + _block_width(block_width), + _subblock_height(subblock_height), + _subblock_width(subblock_width) { +} + +bool DataFormat::operator==(const DataFormat& rhs) const { + return _data_type == rhs._data_type && _quant_format == rhs._quant_format && _scale_dt == rhs._scale_dt && + _zero_point_dt == rhs._zero_point_dt && _block_height == rhs._block_height && _block_width == rhs._block_width; +} + +bool DataFormat::operator!=(const DataFormat& rhs) const { + return !(*this == rhs); +} + +DataType DataFormat::data_type() const { + return _data_type; +} + +DataFormat::QuantizationFormat DataFormat::quantization_format() const { + return _quant_format; +} + +DataType DataFormat::scale_data_type() const { + return _scale_dt; +} + +DataType DataFormat::zero_point_data_type() const { + return _zero_point_dt; +} + +bool DataFormat::is_raw() const { + return _quant_format == QuantizationFormat::NONE && // + _block_height == 0 && _block_width == 0 && _subblock_height == 0 && _subblock_width == 0; +} + +size_t DataFormat::block_height() const { + return _block_height; +} + +size_t DataFormat::block_width() const { + return _block_width; +} + +size_t DataFormat::subblock_height() const { + return _subblock_height; +} + +size_t DataFormat::subblock_width() const { + return _subblock_width; +} + +size_t DataFormat::scheduler_block_height([[maybe_unused]] size_t full_height) const { + switch (_quant_format) { + case QuantizationFormat::NONE: + return _block_height > 0 ? 
_block_height : 1; + + case QuantizationFormat::PER_ROW: + return _block_height; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +size_t DataFormat::scheduler_block_width(size_t full_width) const { + switch (_quant_format) { + case QuantizationFormat::NONE: + return _block_width > 0 ? _block_width : 1; + + case QuantizationFormat::PER_ROW: + return full_width; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +uintptr_t DataFormat::default_row_stride(size_t width) const { + const auto padded_width = round_up_multiple(width, _block_width > 0 ? _block_width : 1); + + switch (_quant_format) { + case QuantizationFormat::NONE: + return padded_width * data_type_size_in_bits(_data_type) / 8; + + case QuantizationFormat::PER_ROW: + return _block_height * data_type_size_in_bits(_zero_point_dt) / 8 + // + _block_height * padded_width * data_type_size_in_bits(_data_type) / 8 + // + _block_height * data_type_size_in_bits(_scale_dt) / 8; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +uintptr_t DataFormat::default_offset_in_bytes(size_t row, size_t col, size_t width) const { + const auto row_stride = default_row_stride(width); + + KAI_ASSERT(col % scheduler_block_width(width) == 0); + + switch (_quant_format) { + case QuantizationFormat::NONE: + return row * row_stride + col * data_type_size_in_bits(_data_type) / 8; + + case QuantizationFormat::PER_ROW: + KAI_ASSERT(row % _block_height == 0); + KAI_ASSERT(col == 0); + return (row / _block_height) * row_stride + col * data_type_size_in_bits(_data_type) / 8; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +size_t DataFormat::default_size_in_bytes(size_t height, size_t width) const { + const auto num_rows = _block_height > 0 ? 
(height + _block_height - 1) / _block_height : height; + return num_rows * default_row_stride(width); +} + +} // namespace kai::test diff --git a/test/common/data_format.hpp b/test/common/data_format.hpp new file mode 100644 index 00000000..3aa19f71 --- /dev/null +++ b/test/common/data_format.hpp @@ -0,0 +1,126 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "test/common/data_type.hpp" + +namespace kai::test { + +/// Data format. +class DataFormat { +public: + /// Quantization packing format. + enum class QuantizationFormat : uint32_t { + NONE, ///< No quantization information is included. + PER_ROW, ///< Per-row quantization. + }; + + /// Creates a new data format. + /// + /// @param[in] data_type Data type of data value. + /// @param[in] block_height Block height. + /// @param[in] block_width Block width. + /// @param[in] quant_format Quantization packing format. + /// @param[in] scale_dt Data type of scale value. + /// @param[in] zero_point_dt Data type of zero point value. + /// @param[in] subblock_height Sub-block height. + /// @param[in] subblock_width Sub-block width. + DataFormat( + DataType data_type, size_t block_height = 0, size_t block_width = 0, + QuantizationFormat quant_format = QuantizationFormat::NONE, DataType scale_dt = DataType::UNKNOWN, + DataType zero_point_dt = DataType::UNKNOWN, size_t subblock_height = 0, size_t subblock_width = 0) noexcept; + + /// Equality operator. + [[nodiscard]] bool operator==(const DataFormat& rhs) const; + + /// Unequality operator. + [[nodiscard]] bool operator!=(const DataFormat& rhs) const; + + /// Gets the quantization packing format. + [[nodiscard]] QuantizationFormat quantization_format() const; + + /// Gets the data type of data value. + [[nodiscard]] DataType data_type() const; + + /// Gets the data type of scale value. 
+ [[nodiscard]] DataType scale_data_type() const; + + /// Gets the data type of zero point value. + [[nodiscard]] DataType zero_point_data_type() const; + + /// Gets a value indicating whether this format has no blocking or packed quantization information. + [[nodiscard]] bool is_raw() const; + + /// Gets the block height. + [[nodiscard]] size_t block_height() const; + + /// Gets the block width. + [[nodiscard]] size_t block_width() const; + + /// Gets the sub-block height. + [[nodiscard]] size_t subblock_height() const; + + /// Gets the sub-block width. + [[nodiscard]] size_t subblock_width() const; + + /// Gets the scheduling block height. + /// + /// @param[in] full_height Height of the full matrix. + /// + /// @return The block height for scheduling purpose. + [[nodiscard]] size_t scheduler_block_height(size_t full_height) const; + + /// Gets the scheduling block width. + /// + /// @param[in] full_width Width of the full matrix. + /// + /// @return The block width for scheduling purpose. + [[nodiscard]] size_t scheduler_block_width(size_t full_width) const; + + /// Gets the row stride in bytes given the data is stored continuously without any gap in the memory. + /// + /// In case of per-row quantization, the row stride is the number of bytes from one row group + /// to the next. One row group consists of `block_height` rows. + /// + /// @param[in] width Width of the full matrix. + /// + /// @return The default row stride in bytes of the matrix. + [[nodiscard]] uintptr_t default_row_stride(size_t width) const; + + /// Gets the offsets in bytes in the data buffer given the data is stored continuously + /// without any gap in the memory. + /// + /// @param[in] row Row coordinate. + /// @param[in] col Colum coordinate. + /// @param[in] width Width of the full matrix. + /// + /// @return The default offset in bytes. 
+ [[nodiscard]] uintptr_t default_offset_in_bytes(size_t row, size_t col, size_t width) const; + + /// Gets the size in bytes of the matrix given the data is stored continuously without any gap in the memory. + /// + /// @param[in] height Height of the full matrix. + /// @param[in] width Width of the full matrix. + /// + /// @return The size in bytes of the matrix. + [[nodiscard]] size_t default_size_in_bytes(size_t height, size_t width) const; + +private: + DataType _data_type; + QuantizationFormat _quant_format; + DataType _scale_dt; + DataType _zero_point_dt; + size_t _block_height; + size_t _block_width; + size_t _subblock_height; + size_t _subblock_width; +}; + +} // namespace kai::test diff --git a/test/common/data_type.cpp b/test/common/data_type.cpp new file mode 100644 index 00000000..79cc7e7c --- /dev/null +++ b/test/common/data_type.cpp @@ -0,0 +1,81 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/data_type.hpp" + +#include +#include + +#include "src/kai_common.h" + +namespace kai::test { + +namespace { + +bool has_i(DataType dt) { + return (static_cast(dt) & (1 << 15)) != 0; +} + +bool has_s(DataType dt) { + return (static_cast(dt) & (1 << 14)) != 0; +} + +bool has_q(DataType dt) { + return (static_cast(dt) & (1 << 13)) != 0; +} + +bool has_a(DataType dt) { + return (static_cast(dt) & (1 << 12)) != 0; +} + +size_t bits(DataType dt) { + return static_cast(dt) & 0xFF; +} + +} // namespace + +size_t data_type_size_in_bits(DataType dt) { + return bits(dt); +} + +bool data_type_is_integral(DataType dt) { + return has_i(dt); +} + +bool data_type_is_float(DataType dt) { + KAI_ASSERT(data_type_is_signed(dt)); + return !data_type_is_integral(dt); +} + +bool data_type_is_float_fp(DataType dt) { + KAI_ASSERT(data_type_is_float(dt)); + return !has_q(dt); +} + +bool data_type_is_float_bf(DataType dt) { + KAI_ASSERT(data_type_is_float(dt)); + return has_q(dt); 
+} + +bool data_type_is_signed(DataType dt) { + if (!has_s(dt)) { + KAI_ASSERT(data_type_is_integral(dt)); + } + + return has_s(dt); +} + +bool data_type_is_quantized(DataType dt) { + KAI_ASSERT_IF(has_q(dt), data_type_is_integral(dt)); + return has_q(dt); +} + +bool data_type_is_quantized_asymm(DataType dt) { + KAI_ASSERT_IF(has_a(dt), data_type_is_quantized(dt)); + return has_a(dt); +} + +} // namespace kai::test diff --git a/test/common/data_type.hpp b/test/common/data_type.hpp new file mode 100644 index 00000000..8fbdcf9d --- /dev/null +++ b/test/common/data_type.hpp @@ -0,0 +1,110 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace kai::test { + +/// Data type. +enum class DataType : uint16_t { + // Encoding: + // + // 15 0 + // +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ + // | i | s | q | a | RES0 | bits | + // +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ + // + // (RES0: reserved, filled with 0s) + // + // Fields: + // + // * i: integer (1) or floating-point (0). + // * s: signed (1) or unsigned (0). + // * q: + // - Integer (i): quantized (1) or non-quantized (0). + // - Floating-point (!i): brain (1) or binary (0). + // * a: + // - Quantized (i && q): asymmetric (1) or symmetric (0). + // - Otherwise: RES0. + // * bits: size in bits. + + UNKNOWN = 0, ///< No data. + + FP32 = 0b0'1'0'0'0000'00100000, ///< Single-precision floating-point. + FP16 = 0b0'1'0'0'0000'00010000, ///< Half-precision floating-point. + + I32 = 0b1'1'0'0'0000'00100000, ///< 32-bit signed integer. + + QI8 = 0b1'1'1'1'0000'00001000, ///< 8-bit asymmetric quantized. + + QSU4 = 0b1'0'1'0'0000'00000100, ///< 4-bit unsigned symmetric quantized. + QSI4 = 0b1'1'1'0'0000'00000100, ///< 4-bit signed symmetric quantized. +}; + +/// Gets the size in bits of the specified data type. +/// +/// @param[in] dt The data type. 
+/// +/// @return The size in bits. +[[nodiscard]] size_t data_type_size_in_bits(DataType dt); + +/// Gets a value indicating whether the data type is integral. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is integral. +[[nodiscard]] bool data_type_is_integral(DataType dt); + +/// Gets a value indicating whether the data type is floating-point. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is floating-point. +[[nodiscard]] bool data_type_is_float(DataType dt); + +/// Gets a value indicating whether the data type is binary floating-point. +/// +/// Binary floating point are `half`, `float`, `double`. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is binary floating-point. +[[nodiscard]] bool data_type_is_float_fp(DataType dt); + +/// Gets a value indicating whether the data type is brain floating-point. +/// +/// Binary floating point are `bfloat16`. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is brain floating-point. +[[nodiscard]] bool data_type_is_float_bf(DataType dt); + +/// Gets a value indicating whether the data type is signed. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is signed. +[[nodiscard]] bool data_type_is_signed(DataType dt); + +/// Gets a value indicating whether the data type is quantized. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is quantized. +[[nodiscard]] bool data_type_is_quantized(DataType dt); + +/// Gets a value indicating whether the data type is asymmetric quantized. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is asymmetric quantized. 
+[[nodiscard]] bool data_type_is_quantized_asymm(DataType dt); + +} // namespace kai::test diff --git a/test/common/int4.cpp b/test/common/int4.cpp new file mode 100644 index 00000000..bc9563b1 --- /dev/null +++ b/test/common/int4.cpp @@ -0,0 +1,98 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/int4.hpp" + +#include +#include + +namespace kai::test { + +UInt4& UInt4::operator=(uint8_t value) { + _value = value; + return *this; +} + +UInt4::operator int32_t() const { + return _value; +} + +UInt4::operator float() const { + return _value; +} + +UInt4 UInt4::operator+(UInt4 rhs) const { + return UInt4(_value + rhs._value); +} + +UInt4 UInt4::operator-(UInt4 rhs) const { + return UInt4(_value - rhs._value); +} + +UInt4 UInt4::operator*(UInt4 rhs) const { + return UInt4(_value * rhs._value); +} + +UInt4 UInt4::operator/(UInt4 rhs) const { + return UInt4(_value / rhs._value); +} + +uint8_t UInt4::pack_u8(UInt4 low, UInt4 high) { + return (low._value & 0x0F) | (high._value << 4); +} + +std::tuple UInt4::unpack_u8(uint8_t value) { + const uint8_t low = value & 0x0F; + const uint8_t high = value >> 4; + + return {UInt4(low), UInt4(high)}; +} + +// ===================================================================================================================== + +Int4& Int4::operator=(int8_t value) { + _value = value; + return *this; +} + +Int4::operator int32_t() const { + return _value; +} + +Int4::operator float() const { + return _value; +} + +Int4 Int4::operator+(Int4 rhs) const { + return Int4(static_cast(_value + rhs._value)); +} + +Int4 Int4::operator-(Int4 rhs) const { + return Int4(static_cast(_value - rhs._value)); +} + +Int4 Int4::operator*(Int4 rhs) const { + return Int4(static_cast(_value * rhs._value)); +} + +Int4 Int4::operator/(Int4 rhs) const { + return Int4(static_cast(_value / rhs._value)); +} + +uint8_t Int4::pack_u8(Int4 low, Int4 high) { + 
return (low._value & 0x0F) | (high._value << 4); +} + +std::tuple Int4::unpack_u8(uint8_t value) { + // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + const int8_t low = static_cast(value << 4) >> 4; + const int8_t high = static_cast(value) >> 4; + // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + + return {Int4(low), Int4(high)}; +} + +} // namespace kai::test diff --git a/test/common/int4.hpp b/test/common/int4.hpp new file mode 100644 index 00000000..169c5817 --- /dev/null +++ b/test/common/int4.hpp @@ -0,0 +1,112 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace kai::test { + +/// 4-bit unsigned integer. +class UInt4 { +public: + /// Creates a new 4-bit unsigned integer value. + /// + /// @param[in] value Value. + constexpr explicit UInt4(uint8_t value) : _value(value) { + } + + /// Assignment operator. + UInt4& operator=(uint8_t value); + + /// Conversion operator. + operator int32_t() const; + + /// Conversion operator. + operator float() const; + + /// Addition operator. + [[nodiscard]] UInt4 operator+(UInt4 rhs) const; + + /// Subtraction operator. + [[nodiscard]] UInt4 operator-(UInt4 rhs) const; + + /// Multiplication operator. + [[nodiscard]] UInt4 operator*(UInt4 rhs) const; + + /// Division operator. + [[nodiscard]] UInt4 operator/(UInt4 rhs) const; + + /// Packs two 4-bit unsigned integer values into one byte. + /// + /// @param[in] low Low nibble. + /// @param[in] high High nibble. + /// + /// @return The packed byte. + [[nodiscard]] static uint8_t pack_u8(UInt4 low, UInt4 high); + + /// Unpacks one byte to two 4-bit unsigned integer values. + /// + /// @param[in] value 8-bit packed value. + /// + /// @return The low and high nibbles. 
+ [[nodiscard]] static std::tuple unpack_u8(uint8_t value); + +private: + uint8_t _value; +}; + +/// 4-bit signed integer. +class Int4 { +public: + /// Creates a new 4-bit signed integer value. + /// + /// @param[in] value Value. + constexpr explicit Int4(int8_t value) : _value(value) { + } + + /// Assignment operator. + Int4& operator=(int8_t value); + + /// Conversion operator. + explicit operator int32_t() const; + + /// Conversion operator. + explicit operator float() const; + + /// Addition operator. + [[nodiscard]] Int4 operator+(Int4 rhs) const; + + /// Subtraction operator. + [[nodiscard]] Int4 operator-(Int4 rhs) const; + + /// Multiplication operator. + [[nodiscard]] Int4 operator*(Int4 rhs) const; + + /// Division operator. + [[nodiscard]] Int4 operator/(Int4 rhs) const; + + /// Packs two 4-bit signed integer values into one byte. + /// + /// @param[in] low Low nibble. + /// @param[in] high High nibble. + /// + /// @return The packed byte. + [[nodiscard]] static uint8_t pack_u8(Int4 low, Int4 high); + + /// Unpacks one byte to two 4-bit signed integer values. + /// + /// @param[in] value 8-bit packed value. + /// + /// @return The low and high nibbles. + [[nodiscard]] static std::tuple unpack_u8(uint8_t value); + +private: + int8_t _value; +}; + +} // namespace kai::test diff --git a/test/common/logging.hpp b/test/common/logging.hpp new file mode 100644 index 00000000..ff6130a0 --- /dev/null +++ b/test/common/logging.hpp @@ -0,0 +1,68 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "test/common/int4.hpp" + +#define KAI_LOGE(...) kai::test::detail::log("ERROR", __VA_ARGS__) + +namespace kai::test::detail { + +/// Prints the specified value to standard error. +/// +/// @tparam T Data type. +/// +/// @param[in] value Value to be printed out. 
+template +void write_log_content(T&& value) { + using TT = std::decay_t; + + if constexpr (std::is_same_v) { + std::cerr << static_cast(value); + } else if constexpr (std::is_same_v) { + std::cerr << static_cast(value); + } else if constexpr (std::is_same_v) { + std::cerr << static_cast(value); + } else if constexpr (std::is_same_v) { + std::cerr << static_cast(value); + } else { + std::cerr << value; + } +} + +/// Prints the specified values to standard error. +/// +/// @tparam T Data type of the first value. +/// @tparam Ts Data types of the subsequent values. +/// +/// @param[in] value First value to be printed out. +/// @param[in] others Subsequent values to be printed out. +template +void write_log_content(T&& value, Ts&&... others) { + write_log_content(std::forward(value)); + write_log_content(std::forward(others)...); +} + +/// Prints the log to standard error. +/// +/// @tparam Ts Data types of values to be printed out. +/// +/// @param[in] level Severity level. +/// @param[in] args Values to be printed out. +template +void log(std::string_view level, Ts&&... args) { + std::cerr << level << " | "; + write_log_content(std::forward(args)...); + std::cerr << "\n"; +} + +} // namespace kai::test::detail diff --git a/test/common/memory.hpp b/test/common/memory.hpp new file mode 100644 index 00000000..bf5fbb01 --- /dev/null +++ b/test/common/memory.hpp @@ -0,0 +1,77 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "test/common/int4.hpp" + +namespace kai::test { + +/// The size in bits of type `T`. +template +inline constexpr size_t size_in_bits = sizeof(T) * 8; + +/// The size in bits of type `T`. +template <> +inline constexpr size_t size_in_bits = 4; + +/// The size in bits of type `T`. +template <> +inline constexpr size_t size_in_bits = 4; + +/// Reads the array at the specified index. +/// +/// @param[in] array Data buffer. 
+/// @param[in] index Array index. +/// +/// @return The array value at the specified index. +template +T read_array(const void* array, size_t index) { + if constexpr (std::is_same_v) { + const auto [lo, hi] = UInt4::unpack_u8(reinterpret_cast(array)[index / 2]); + return index % 2 == 0 ? lo : hi; + } else if constexpr (std::is_same_v) { + const auto [lo, hi] = Int4::unpack_u8(reinterpret_cast(array)[index / 2]); + return index % 2 == 0 ? lo : hi; + } else { + return reinterpret_cast(array)[index]; + } +} + +/// Writes the specified value to the array. +/// +/// @param[in] array Data buffer. +/// @param[in] index Array index. +/// @param[in] value Value to be stored. +template +void write_array(void* array, size_t index, T value) { + if constexpr (std::is_same_v) { + auto* arr_value = reinterpret_cast(array) + index / 2; + const auto [lo, hi] = UInt4::unpack_u8(*arr_value); + + if (index % 2 == 0) { + *arr_value = UInt4::pack_u8(value, hi); + } else { + *arr_value = UInt4::pack_u8(lo, value); + } + } else if constexpr (std::is_same_v) { + auto* arr_value = reinterpret_cast(array) + index / 2; + const auto [lo, hi] = Int4::unpack_u8(*arr_value); + + if (index % 2 == 0) { + *arr_value = Int4::pack_u8(value, hi); + } else { + *arr_value = Int4::pack_u8(lo, value); + } + } else { + reinterpret_cast(array)[index] = value; + } +} + +} // namespace kai::test diff --git a/test/common/numeric_limits.hpp b/test/common/numeric_limits.hpp new file mode 100644 index 00000000..e5950810 --- /dev/null +++ b/test/common/numeric_limits.hpp @@ -0,0 +1,39 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "test/common/int4.hpp" + +namespace kai::test { + +/// Highest finite value of type `T`. +template +inline constexpr T numeric_highest = std::numeric_limits::max(); + +/// Highest finite value of type `T`. 
+template <> +inline constexpr UInt4 numeric_highest{15}; + +/// Highest finite value of type `T`. +template <> +inline constexpr Int4 numeric_highest{7}; + +/// Lowest finite value of type `T`. +template +inline constexpr T numeric_lowest = std::numeric_limits::lowest(); + +/// Lowest finite value of type `T`. +template <> +inline constexpr UInt4 numeric_lowest{0}; + +/// Lowest finite value of type `T`. +template <> +inline constexpr Int4 numeric_lowest{-8}; + +} // namespace kai::test diff --git a/test/common/printer.cpp b/test/common/printer.cpp new file mode 100644 index 00000000..368ae7db --- /dev/null +++ b/test/common/printer.cpp @@ -0,0 +1,112 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/common/int4.hpp" + +namespace kai::test { + +namespace { + +inline void print_data(std::ostream& os, const uint8_t* data, size_t len, DataType data_type) { + if (data_type == DataType::QSU4) { + for (size_t i = 0; i < len / 2; ++i) { + const auto [low, high] = UInt4::unpack_u8(data[i]); + os << static_cast(low) << ", " << static_cast(high) << ", "; + } + } else if (data_type == DataType::QSI4) { + for (size_t i = 0; i < len / 2; ++i) { + const auto [low, high] = Int4::unpack_u8(data[i]); + os << static_cast(low) << ", " << static_cast(high) << ", "; + } + } else { + for (size_t i = 0; i < len; ++i) { + switch (data_type) { + case DataType::FP32: + os << reinterpret_cast(data)[i]; + break; + + case DataType::I32: + os << reinterpret_cast(data)[i]; + break; + + case DataType::QI8: + os << static_cast(reinterpret_cast(data)[i]); + break; + + default: + KAI_ERROR("Unsupported data type!"); + } + + os << ", "; + } + } +} + +void print_matrix_raw(std::ostream& os, const uint8_t* data, DataType data_type, size_t height, size_t 
width) { + const auto row_stride = width * data_type_size_in_bits(data_type) / 8; + + os << "[\n"; + for (size_t y = 0; y < height; ++y) { + os << " ["; + print_data(os, data + y * row_stride, width, data_type); + os << "],\n"; + } + os << "]\n"; +} + +void print_matrix_per_row( + std::ostream& os, const uint8_t* data, const DataFormat& format, size_t height, size_t width) { + const auto block_height = format.block_height(); + const auto num_blocks = (height + block_height - 1) / block_height; + + const auto block_data_bytes = block_height * width * data_type_size_in_bits(format.data_type()) / 8; + const auto block_offsets_bytes = block_height * data_type_size_in_bits(format.zero_point_data_type()) / 8; + const auto block_scales_bytes = block_height * data_type_size_in_bits(format.scale_data_type()) / 8; + + os << "[\n"; + for (size_t y = 0; y < num_blocks; ++y) { + os << " {\"offsets\": ["; + print_data(os, data, block_height, format.zero_point_data_type()); + os << "], \"data\": ["; + print_data(os, data + block_offsets_bytes, block_height * width, format.data_type()); + os << "], \"scales\": ["; + print_data(os, data + block_offsets_bytes + block_data_bytes, block_height, format.scale_data_type()); + os << "]},\n"; + + data += block_offsets_bytes + block_data_bytes + block_scales_bytes; + } + os << "]\n"; +} + +} // namespace + +void print_matrix( + std::ostream& os, std::string_view name, const void* data, const DataFormat& format, size_t height, size_t width) { + os << name << " = "; + + switch (format.quantization_format()) { + case DataFormat::QuantizationFormat::NONE: + print_matrix_raw(os, reinterpret_cast(data), format.data_type(), height, width); + break; + + case DataFormat::QuantizationFormat::PER_ROW: + print_matrix_per_row(os, reinterpret_cast(data), format, height, width); + break; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +} // namespace kai::test diff --git a/test/common/printer.hpp b/test/common/printer.hpp 
new file mode 100644 index 00000000..ec6213ae --- /dev/null +++ b/test/common/printer.hpp @@ -0,0 +1,28 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace kai::test { + +class DataFormat; + +/// Prints the matrix data to the output stream. +/// +/// @param[in] os Output stream to write the data to. +/// @param[in] name Matrix name. +/// @param[in] data Data buffer. +/// @param[in] format Data format. +/// @param[in] height Number of rows. +/// @param[in] width Number of columns. +void print_matrix( + std::ostream& os, std::string_view name, const void* data, const DataFormat& format, size_t height, size_t width); + +} // namespace kai::test diff --git a/test/common/type_traits.hpp b/test/common/type_traits.hpp new file mode 100644 index 00000000..2267432a --- /dev/null +++ b/test/common/type_traits.hpp @@ -0,0 +1,79 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace kai::test { + +class UInt4; +class Int4; + +/// `true` if `T` is unsigned numeric type. +template +inline constexpr bool is_unsigned = std::is_unsigned_v; + +/// `true` if `T` is unsigned numeric type. +template <> +inline constexpr bool is_unsigned = true; + +/// `true` if `T` is unsigned numeric type. +template <> +inline constexpr bool is_unsigned = true; + +/// `true` if `T` is signed numeric type. +template +inline constexpr bool is_signed = std::is_signed_v; + +/// `true` if `T` is signed numeric type. +template <> +inline constexpr bool is_signed = false; + +/// `true` if `T` is signed numeric type. +template <> +inline constexpr bool is_signed = false; + +/// `true` if `T` is integral numeric type. +template +inline constexpr bool is_integral = std::is_integral_v; + +/// `true` if `T` is integral numeric type. 
+template <> +inline constexpr bool is_integral = true; + +/// `true` if `T` is integral numeric type. +template <> +inline constexpr bool is_integral = true; + +/// `true` if `T` is floating-point type. +template +inline constexpr bool is_floating_point = std::is_floating_point_v; + +/// Signed version of type `T`. +template +struct make_signed { + using type = std::make_signed_t; +}; + +/// Signed version of type `T`. +template <> +struct make_signed { + using type = Int4; +}; + +/// Signed version of type `T`. +template <> +struct make_signed { + using type = Int4; +}; + +/// Signed version of type `T`. +template +using make_signed_t = typename make_signed::type; + +} // namespace kai::test diff --git a/test/reference/round.cpp b/test/reference/round.cpp new file mode 100644 index 00000000..52a2c557 --- /dev/null +++ b/test/reference/round.cpp @@ -0,0 +1,36 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/round.hpp" + +#include +#include + +namespace kai::test { + +int32_t round_to_nearest_even_i32(float value) { + int32_t rounded = 0; + asm("fcvtns %w[output], %s[input]" : [output] "=r"(rounded) : [input] "w"(value)); + return rounded; +} + +size_t round_to_nearest_even_usize(float value) { + static_assert(sizeof(size_t) == sizeof(uint64_t)); + + uint64_t rounded = 0; + asm("fcvtns %x[output], %s[input]" : [output] "=r"(rounded) : [input] "w"(value)); + return rounded; +} + +size_t round_up_multiple(size_t a, size_t b) { + return ((a + b - 1) / b) * b; +} + +size_t round_down_multiple(size_t a, size_t b) { + return (a / b) * b; +} + +} // namespace kai::test diff --git a/test/reference/round.hpp b/test/reference/round.hpp new file mode 100644 index 00000000..7503936f --- /dev/null +++ b/test/reference/round.hpp @@ -0,0 +1,62 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 
+// + +#pragma once + +#include +#include + +namespace kai::test { + +/// Rounds the specified value to nearest with tie to even. +/// +/// For example: +/// +/// * 0.4 is rounded to 0. +/// * 0.5 is rounded to 0 (as 0 is the nearest even value). +/// * 0.6 is rounded to 1. +/// * 1.4 is rounded to 1. +/// * 1.5 is rounded to 2 (as 2 is the nearest even value). +/// * 1.6 is rounded to 2. +/// +/// @param[in] value Value to be rounded. +/// +/// @return The rounded value. +int32_t round_to_nearest_even_i32(float value); + +/// Rounds the specified value to nearest with tie to even. +/// +/// For example: +/// +/// * 0.4 is rounded to 0. +/// * 0.5 is rounded to 0 (as 0 is the nearest even value). +/// * 0.6 is rounded to 1. +/// * 1.4 is rounded to 1. +/// * 1.5 is rounded to 2 (as 2 is the nearest even value). +/// * 1.6 is rounded to 2. +/// +/// @param[in] value Value to be rounded. +/// +/// @return The rounded value. +size_t round_to_nearest_even_usize(float value); + +/// Rounds up the input value to the multiple of the unit value. +/// +/// @param[in] a Input value. +/// @param[in] b Unit value. +/// +/// @return The rounded value. +size_t round_up_multiple(size_t a, size_t b); + +/// Rounds down the input value to the multiple of the unit value. +/// +/// @param[in] a Input value. +/// @param[in] b Unit value. +/// +/// @return The rounded value. +size_t round_down_multiple(size_t a, size_t b); + +} // namespace kai::test -- GitLab From fed53b99e792afd687b8f6d11a8edcfe7e8cfa14 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Tue, 30 Apr 2024 13:26:23 +0100 Subject: [PATCH 2/5] Add matrix portion and comparison * Add matrix portion and rectangular region to specify a sub-matrix from a full matrix. * Add matrix comparison and mismatch handler to compare two matrices. - Support both raw and per-row quantization formats. 
Signed-off-by: Viet-Hoa Do --- CMakeLists.txt | 3 + test/common/compare.cpp | 270 +++++++++++++++++++++++++++++++++ test/common/compare.hpp | 125 +++++++++++++++ test/common/matrix_portion.cpp | 64 ++++++++ test/common/matrix_portion.hpp | 68 +++++++++ test/common/rect.cpp | 41 +++++ test/common/rect.hpp | 51 +++++++ 7 files changed, 622 insertions(+) create mode 100644 test/common/compare.cpp create mode 100644 test/common/compare.hpp create mode 100644 test/common/matrix_portion.cpp create mode 100644 test/common/matrix_portion.hpp create mode 100644 test/common/rect.cpp create mode 100644 test/common/rect.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index b6d9662c..5266a3a5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,6 +60,9 @@ if(KLEIDIAI_BUILD_TESTS) test/common/data_format.cpp test/common/printer.cpp test/common/int4.cpp + test/common/compare.cpp + test/common/matrix_portion.cpp + test/common/rect.cpp test/reference/round.cpp ) diff --git a/test/common/compare.cpp b/test/common/compare.cpp new file mode 100644 index 00000000..206ec0e5 --- /dev/null +++ b/test/common/compare.cpp @@ -0,0 +1,270 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/compare.hpp" + +#include +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/common/int4.hpp" +#include "test/common/logging.hpp" +#include "test/common/memory.hpp" +#include "test/common/printer.hpp" +#include "test/common/rect.hpp" + +namespace kai::test { + +namespace { + +/// Calculates the absolute and relative errors. +/// +/// @param[in] imp Value under test. +/// @param[in] ref Reference value. +/// +/// @return The absolute error and relative error. 
+template +std::tuple calculate_error(T imp, T ref) { + const auto imp_f = static_cast(imp); + const auto ref_f = static_cast(ref); + + const auto abs_error = std::abs(imp_f - ref_f); + const auto rel_error = ref_f != 0 ? abs_error / std::abs(ref_f) : 0.0F; + + return {abs_error, rel_error}; +} + +/// Compares matrices without quantization. +template +bool compare_raw( + const void* imp_data, const void* ref_data, size_t full_height, size_t full_width, const Rect& rect, + MismatchHandler& handler) { + for (size_t y = 0; y < full_height; ++y) { + for (size_t x = 0; x < full_width; ++x) { + const auto in_roi = + y >= rect.start_row() && y < rect.end_row() && x >= rect.start_col() && x < rect.end_col(); + + const auto imp_value = read_array(imp_data, y * full_width + x); + const auto ref_value = in_roi ? read_array(ref_data, y * full_width + x) : 0; + + const auto [abs_err, rel_err] = calculate_error(imp_value, ref_value); + + if (abs_err != 0 || rel_err != 0) { + const auto notifying = !in_roi || handler.handle_data(abs_err, rel_err); + + if (notifying) { + KAI_LOGE("Mismatched data at (", y, ", ", x, "): actual = ", imp_value, ", expected: ", ref_value); + } + } + } + } + + return handler.success(full_height * full_width); +} + +/// Compares matrices with per-row quantization.
+template +bool compare_per_row( + const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, + const Rect& rect, MismatchHandler& handler) { + const auto block_height = format.block_height(); + const auto block_width = format.block_width(); + + KAI_ASSUME(format.scheduler_block_height(full_height) == block_height); + KAI_ASSUME(format.scheduler_block_width(full_width) == full_width); + KAI_ASSUME(rect.start_col() == 0); + KAI_ASSUME(rect.width() == full_width); + + const auto data_bits = size_in_bits; + + const auto num_groups = (full_height + block_height - 1) / block_height; + const auto group_num_blocks = (full_width + block_width - 1) / block_width; + + const auto group_offsets_bytes = block_height * sizeof(Offset); + const auto group_scales_bytes = block_height * sizeof(Scale); + const auto block_data_bytes = block_height * block_width * data_bits / 8; + + const auto begin_group = rect.start_row() / block_height; + const auto end_group = rect.end_row() / block_height; + + const auto* imp_ptr = reinterpret_cast(imp_data); + const auto* ref_ptr = reinterpret_cast(ref_data); + + for (size_t group_no = 0; group_no < num_groups; ++group_no) { + const auto in_roi = group_no >= begin_group && group_no < end_group; + + // Checks the quantization offsets. + for (size_t i = 0; i < block_height; ++i) { + const auto imp_offset = reinterpret_cast(imp_ptr)[i]; + const Offset ref_offset = in_roi ? reinterpret_cast(ref_ptr)[i] : 0; + const auto [abs_err, rel_err] = calculate_error(imp_offset, ref_offset); + + if (abs_err != 0 || rel_err != 0) { + handler.mark_as_failed(); + + const auto raw_row = group_no * block_height + i; + KAI_LOGE( + "Mismatched quantization offset ", raw_row, ": actual = ", imp_offset, ", expected: ", ref_offset); + } + } + + imp_ptr += group_offsets_bytes; + ref_ptr += group_offsets_bytes; + + // Checks the data. 
+ for (size_t block_no = 0; block_no < group_num_blocks; ++block_no) { + for (size_t y = 0; y < block_height; ++y) { + for (size_t x = 0; x < block_width; ++x) { + const auto imp_data = read_array(imp_ptr, y * block_width + x); + const Data ref_data = in_roi ? read_array(ref_ptr, y * block_width + x) : Data(0); + const auto [abs_err, rel_err] = calculate_error(imp_data, ref_data); + + if (abs_err != 0 || rel_err != 0) { + const auto notifying = !in_roi || handler.handle_data(abs_err, rel_err); + + if (notifying) { + const auto raw_row = group_no * block_height + y; + const auto raw_col = block_no * block_width + x; + + KAI_LOGE( + "Mismatched data at (", raw_row, ", ", raw_col, "): actual = ", imp_data, + ", expected: ", ref_data); + } + } + } + } + + imp_ptr += block_data_bytes; + ref_ptr += block_data_bytes; + } + + // Checks the quantization scales. + for (size_t i = 0; i < block_height; ++i) { + const auto imp_scale = reinterpret_cast(imp_ptr)[i]; + const Scale ref_scale = in_roi ? reinterpret_cast(ref_ptr)[i] : 0; + const auto [abs_err, rel_err] = calculate_error(imp_scale, ref_scale); + + if (abs_err != 0 || rel_err != 0) { + handler.mark_as_failed(); + + const auto raw_row = group_no * block_height + i; + KAI_LOGE( + "Mismatched quantization scale ", raw_row, ": actual = ", imp_scale, ", expected: ", ref_scale); + } + } + + imp_ptr += group_scales_bytes; + ref_ptr += group_scales_bytes; + } + + return handler.success(rect.height() * full_width); +} + +} // namespace + +bool compare( + const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, + const Rect& rect, MismatchHandler& handler) { + const auto data_type = format.data_type(); + const auto scale_dt = format.scale_data_type(); + const auto offset_dt = format.zero_point_data_type(); + + switch (format.quantization_format()) { + case DataFormat::QuantizationFormat::NONE: + switch (data_type) { + case DataType::FP32: + return compare_raw(imp_data, 
ref_data, full_height, full_width, rect, handler); + + default: + break; + } + + case DataFormat::QuantizationFormat::PER_ROW: + if (data_type == DataType::QI8 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { + return compare_per_row( + imp_data, ref_data, format, full_height, full_width, rect, handler); + } else if (data_type == DataType::QSI4 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { + return compare_per_row( + imp_data, ref_data, format, full_height, full_width, rect, handler); + } + + break; + + default: + break; + } + + KAI_ERROR("Unsupported format!"); +} + +// ===================================================================================================================== + +DefaultMismatchHandler::DefaultMismatchHandler( + float abs_error_threshold, float rel_error_threshold, size_t abs_mismatched_threshold, + float rel_mismatched_threshold) : + _abs_error_threshold(abs_error_threshold), + _rel_error_threshold(rel_error_threshold), + _abs_mismatched_threshold(abs_mismatched_threshold), + _rel_mismatched_threshold(rel_mismatched_threshold), + _num_mismatches(0), + _failed(false) { +} + +DefaultMismatchHandler::DefaultMismatchHandler(const DefaultMismatchHandler& rhs) : + _abs_error_threshold(rhs._abs_error_threshold), + _rel_error_threshold(rhs._rel_error_threshold), + _abs_mismatched_threshold(rhs._abs_mismatched_threshold), + _rel_mismatched_threshold(rhs._rel_mismatched_threshold), + _num_mismatches(0), + _failed(false) { + // Cannot copy mismatch handler that is already in use. + KAI_ASSUME(rhs._num_mismatches == 0); + KAI_ASSUME(!rhs._failed); +} + +DefaultMismatchHandler& DefaultMismatchHandler::operator=(const DefaultMismatchHandler& rhs) { + if (this != &rhs) { + // Cannot copy mismatch handler that is already in use. 
+ KAI_ASSUME(rhs._num_mismatches == 0); + KAI_ASSUME(!rhs._failed); + + _abs_error_threshold = rhs._abs_error_threshold; + _rel_error_threshold = rhs._rel_error_threshold; + _abs_mismatched_threshold = rhs._abs_mismatched_threshold; + _rel_mismatched_threshold = rhs._rel_mismatched_threshold; + } + + return *this; +} + +bool DefaultMismatchHandler::handle_data(float absolute_error, float relative_error) { + const auto mismatched = absolute_error > _abs_error_threshold && relative_error > _rel_error_threshold; + + if (mismatched) { + ++_num_mismatches; + } + + return mismatched; +} + +void DefaultMismatchHandler::mark_as_failed() { + _failed = true; +} + +bool DefaultMismatchHandler::success(size_t num_checks) const { + if (_failed) { + return false; + } + + const auto mismatched_rate = static_cast(_num_mismatches) / static_cast(num_checks); + return _num_mismatches <= _abs_mismatched_threshold || mismatched_rate <= _rel_mismatched_threshold; +} + +} // namespace kai::test diff --git a/test/common/compare.hpp b/test/common/compare.hpp new file mode 100644 index 00000000..3bca0a5c --- /dev/null +++ b/test/common/compare.hpp @@ -0,0 +1,125 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace kai::test { + +class DataFormat; +class Rect; +class MismatchHandler; + +/// Compares two matrices to check whether they are matched. +/// +/// @param[in] imp_data Data buffer of the actual implementation matrix. +/// @param[in] ref_data Data buffer of the reference implementation matrix. +/// @param[in] format Data format. +/// @param[in] full_height Height of the full matrix. +/// @param[in] full_width Width of the full matrix. +/// @param[in] rect Rectangular region of the matrix that is populated with data. +/// @param[in] handler Mismatch handler. +/// +/// @return `true` if the two matrices are considered matched. 
+bool compare( + const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, + const Rect& rect, MismatchHandler& handler); + +/// Handles mismatches found by @ref compare function. +class MismatchHandler { +public: + /// Constructor. + MismatchHandler() = default; + + /// Destructor. + virtual ~MismatchHandler() = default; + + /// Copy constructor. + MismatchHandler(const MismatchHandler&) = default; + + /// Copy assignment. + MismatchHandler& operator=(const MismatchHandler&) = default; + + /// Move constructor. + MismatchHandler(MismatchHandler&&) noexcept = default; + + /// Move assignment. + MismatchHandler& operator=(MismatchHandler&&) noexcept = default; + + /// Handles new mismatch result. + /// + /// This method must be called even when no error is detected. + /// + /// @param[in] absolute_error Absolute error. + /// @param[in] relative_error Relative error. + /// + /// @return `true` if the mismatch is sufficiently large to be logged as real mismatch. + virtual bool handle_data(float absolute_error, float relative_error) = 0; + + /// Marks the result as failed. + /// + /// It is zero tolerance if the data point is considered impossible to have mismatch + /// regardless of implementation method. + /// These normally include data points outside of the portion of interest (these must be 0) + /// and data points belonging to quantization information. + virtual void mark_as_failed() = 0; + + /// Returns a value indicating whether the two matrices are considered matched. + /// + /// @param[in] num_checks Total number of data points that have been checked. + /// + /// @return `true` if the two matrices are considered matched. + [[nodiscard]] virtual bool success(size_t num_checks) const = 0; +}; + +/// This mismatch handler considers two values being mismatched when both the relative error +/// and the absolute error exceed their respective thresholds.
+/// +/// This mismatch handler considers two matrices being mismatched when the number of mismatches +/// exceeds both the relative and absolute thresholds. +class DefaultMismatchHandler final : public MismatchHandler { +public: + /// Creates a new mismatch handler. + /// + /// @param[in] abs_error_threshold Threshold for absolute error. + /// @param[in] rel_error_threshold Threshold for relative error. + /// @param[in] abs_mismatched_threshold Threshold for the number of mismatched data points. + /// @param[in] rel_mismatched_threshold Threshold for the ratio of mismatched data points. + DefaultMismatchHandler( + float abs_error_threshold, float rel_error_threshold, size_t abs_mismatched_threshold, + float rel_mismatched_threshold); + + /// Destructor. + ~DefaultMismatchHandler() = default; + + /// Copy constructor. + DefaultMismatchHandler(const DefaultMismatchHandler& rhs); + + /// Copy assignment. + DefaultMismatchHandler& operator=(const DefaultMismatchHandler& rhs); + + /// Move constructor. + DefaultMismatchHandler(DefaultMismatchHandler&& rhs) noexcept = default; + + /// Move assignment.
+ DefaultMismatchHandler& operator=(DefaultMismatchHandler&& rhs) noexcept = default; + + bool handle_data(float absolute_error, float relative_error) override; + void mark_as_failed() override; + [[nodiscard]] bool success(size_t num_checks) const override; + +private: + float _abs_error_threshold; + float _rel_error_threshold; + size_t _abs_mismatched_threshold; + float _rel_mismatched_threshold; + + size_t _num_mismatches; + bool _failed; +}; + +} // namespace kai::test diff --git a/test/common/matrix_portion.cpp b/test/common/matrix_portion.cpp new file mode 100644 index 00000000..abf795df --- /dev/null +++ b/test/common/matrix_portion.cpp @@ -0,0 +1,64 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/matrix_portion.hpp" + +#include +#include + +#include "test/common/rect.hpp" +#include "test/reference/round.hpp" + +namespace kai::test { + +MatrixPortion::MatrixPortion(float start_row, float start_col, float height, float width) : + _start_row(start_row), _start_col(start_col), _height(height), _width(width) { +} + +float MatrixPortion::start_row() const { + return _start_row; +} + +float MatrixPortion::start_col() const { + return _start_col; +} + +float MatrixPortion::height() const { + return _height; +} + +float MatrixPortion::width() const { + return _width; +} + +Rect MatrixPortion::compute_portion( + size_t full_height, size_t full_width, size_t scheduler_block_height, size_t scheduler_block_width) const { + const auto start_row_f = std::clamp(_start_row, 0, 1); + const auto start_col_f = std::clamp(_start_col, 0, 1); + const auto height_f = std::clamp(_height, 0, 1); + const auto width_f = std::clamp(_width, 0, 1); + + auto start_row = round_to_nearest_even_usize(start_row_f * static_cast(full_height)); + auto start_col = round_to_nearest_even_usize(start_col_f * static_cast(full_width)); + auto height = round_to_nearest_even_usize(height_f * 
static_cast(full_height)); + auto width = round_to_nearest_even_usize(width_f * static_cast(full_width)); + + start_row = round_down_multiple(start_row, scheduler_block_height); + start_col = round_down_multiple(start_col, scheduler_block_width); + + start_row = std::min(start_row, round_down_multiple(full_height, scheduler_block_height)); + start_col = std::min(start_col, round_down_multiple(full_width, scheduler_block_width)); + + height = round_up_multiple(height, scheduler_block_height); + width = round_up_multiple(width, scheduler_block_width); + + height = std::min(height, full_height - start_row); + width = std::min(width, full_width - start_col); + + return {start_row, start_col, height, width}; +} + +} // namespace kai::test diff --git a/test/common/matrix_portion.hpp b/test/common/matrix_portion.hpp new file mode 100644 index 00000000..02f9f548 --- /dev/null +++ b/test/common/matrix_portion.hpp @@ -0,0 +1,68 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "test/common/rect.hpp" + +namespace kai::test { + +/// Portion of a matrix. +/// +/// This class is used to define the sub-matrix a test is running and checking. +/// +/// This is the relative version of @ref Rect. +class MatrixPortion { +public: + /// Creates a new matrix portion. + /// + /// @param[in] start_row Starting row as the ratio to the height of the matrix. + /// @param[in] start_col Starting column as the ratio to the width of the matrix. + /// @param[in] height Portion height as the ratio to the height of the matrix. + /// @param[in] width Portion width as the ratio to the width of the matrix. + MatrixPortion(float start_row, float start_col, float height, float width); + + /// Gets the starting row as the ratio to the height of the matrix. + [[nodiscard]] float start_row() const; + + /// Gets the starting column as the ratio to the width of the matrix. 
+ [[nodiscard]] float start_col() const; + + /// Gets the portion height as the ratio to the height of the matrix. + [[nodiscard]] float height() const; + + /// Gets the portion width as the ratio to the width of the matrix. + [[nodiscard]] float width() const; + + /// Computes the starting coordinate and the shape of the sub-matrix. + /// + /// Requirements: + /// + /// * The starting coordinate of the sub-matrix shall be aligned with the block boundary. + /// * If it is not the block at the right and/or bottom edge of the full matrix, the height and width + /// of the sub-matrix shall be rounded up to multiple of the block height and width. + /// * If it is the block at the right and/or bottom edge of the full matrix, the height and width + /// of the sub-matrix shall be the rounded up to the edge of the matrix. + /// + /// @param[in] full_height Matrix height. + /// @param[in] full_width Matrix width. + /// @param[in] scheduler_block_height Block height for scheduling purpose. + /// @param[in] scheduler_block_width Block width for scheduling purpose. + /// + /// @return The rectangular region of the matrix. 
+ [[nodiscard]] Rect compute_portion( + size_t full_height, size_t full_width, size_t scheduler_block_height, size_t scheduler_block_width) const; + +private: + float _start_row; + float _start_col; + float _height; + float _width; +}; + +} // namespace kai::test diff --git a/test/common/rect.cpp b/test/common/rect.cpp new file mode 100644 index 00000000..457ac2f3 --- /dev/null +++ b/test/common/rect.cpp @@ -0,0 +1,41 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/rect.hpp" + +#include + +namespace kai::test { + +Rect::Rect(size_t start_row, size_t start_col, size_t height, size_t width) : + _start_row(start_row), _start_col(start_col), _height(height), _width(width) { +} + +size_t Rect::start_row() const { + return _start_row; +} + +size_t Rect::start_col() const { + return _start_col; +} + +size_t Rect::height() const { + return _height; +} + +size_t Rect::width() const { + return _width; +} + +size_t Rect::end_row() const { + return _start_row + _height; +} + +size_t Rect::end_col() const { + return _start_col + _width; +} + +} // namespace kai::test diff --git a/test/common/rect.hpp b/test/common/rect.hpp new file mode 100644 index 00000000..92b033b4 --- /dev/null +++ b/test/common/rect.hpp @@ -0,0 +1,51 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace kai::test { + +/// Rectangular region of a matrix. +/// +/// This is the absolute version of @ref MatrixPortion. +class Rect { +public: + /// Creates a new rectangular region of a matrix. + /// + /// @param[in] start_row Starting row index. + /// @param[in] start_col Starting column index. + /// @param[in] height Number of rows. + /// @param[in] width Number of columns. + Rect(size_t start_row, size_t start_col, size_t height, size_t width); + + /// Gets the starting row index. 
+ [[nodiscard]] size_t start_row() const; + + /// Gets the starting column index. + [[nodiscard]] size_t start_col() const; + + /// Gets the number of rows. + [[nodiscard]] size_t height() const; + + /// Gets the number of columns. + [[nodiscard]] size_t width() const; + + /// Gets the end (exclusive) row index. + [[nodiscard]] size_t end_row() const; + + /// Gets the end (exclusive) column index. + [[nodiscard]] size_t end_col() const; + +private: + size_t _start_row; + size_t _start_col; + size_t _height; + size_t _width; +}; + +} // namespace kai::test -- GitLab From 2ec5fc6a71a3ba517c4bebe403e56fe03f45e86f Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Tue, 30 Apr 2024 16:50:45 +0100 Subject: [PATCH 3/5] Fix GCC compilation error Signed-off-by: Viet-Hoa Do --- test/common/compare.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/common/compare.cpp b/test/common/compare.cpp index 206ec0e5..72eba283 100644 --- a/test/common/compare.cpp +++ b/test/common/compare.cpp @@ -186,6 +186,8 @@ bool compare( break; } + break; + case DataFormat::QuantizationFormat::PER_ROW: if (data_type == DataType::QI8 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { return compare_per_row( -- GitLab From 5be25ff33fb098b03abc3ef391c98f12b24cadcc Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 2 May 2024 10:11:11 +0100 Subject: [PATCH 4/5] Change data type QI8 to QAI8 Signed-off-by: Viet-Hoa Do --- test/common/data_type.hpp | 2 +- test/common/printer.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/common/data_type.hpp b/test/common/data_type.hpp index 8fbdcf9d..c707b53f 100644 --- a/test/common/data_type.hpp +++ b/test/common/data_type.hpp @@ -41,7 +41,7 @@ enum class DataType : uint16_t { I32 = 0b1'1'0'0'0000'00100000, ///< 32-bit signed integer. - QI8 = 0b1'1'1'1'0000'00001000, ///< 8-bit asymmetric quantized. + QAI8 = 0b1'1'1'1'0000'00001000, ///< 8-bit unsigned asymmetric quantized. 
QSU4 = 0b1'0'1'0'0000'00000100, ///< 4-bit unsigned symmetric quantized. QSI4 = 0b1'1'1'0'0000'00000100, ///< 4-bit signed symmetric quantized. diff --git a/test/common/printer.cpp b/test/common/printer.cpp index 368ae7db..9bc7a100 100644 --- a/test/common/printer.cpp +++ b/test/common/printer.cpp @@ -40,7 +40,7 @@ inline void print_data(std::ostream& os, const uint8_t* data, size_t len, DataTy os << reinterpret_cast(data)[i]; break; - case DataType::QI8: + case DataType::QAI8: os << static_cast(reinterpret_cast(data)[i]); break; -- GitLab From 581eae238c55404a18546f14ec81b2e4321732c8 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 2 May 2024 10:27:08 +0100 Subject: [PATCH 5/5] Address various comments Signed-off-by: Viet-Hoa Do --- test/common/compare.cpp | 2 +- test/common/matrix_portion.cpp | 19 ++++++++++--------- test/common/matrix_portion.hpp | 10 +++++----- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/test/common/compare.cpp b/test/common/compare.cpp index 72eba283..f1dfecdf 100644 --- a/test/common/compare.cpp +++ b/test/common/compare.cpp @@ -189,7 +189,7 @@ bool compare( break; case DataFormat::QuantizationFormat::PER_ROW: - if (data_type == DataType::QI8 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { + if (data_type == DataType::QAI8 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { return compare_per_row( imp_data, ref_data, format, full_height, full_width, rect, handler); } else if (data_type == DataType::QSI4 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { diff --git a/test/common/matrix_portion.cpp b/test/common/matrix_portion.cpp index abf795df..e7c42bf1 100644 --- a/test/common/matrix_portion.cpp +++ b/test/common/matrix_portion.cpp @@ -9,6 +9,7 @@ #include #include +#include "src/kai_common.h" #include "test/common/rect.hpp" #include "test/reference/round.hpp" @@ -36,15 +37,15 @@ float MatrixPortion::width() const { Rect MatrixPortion::compute_portion( size_t full_height, size_t 
full_width, size_t scheduler_block_height, size_t scheduler_block_width) const { - const auto start_row_f = std::clamp(_start_row, 0, 1); - const auto start_col_f = std::clamp(_start_col, 0, 1); - const auto height_f = std::clamp(_height, 0, 1); - const auto width_f = std::clamp(_width, 0, 1); - - auto start_row = round_to_nearest_even_usize(start_row_f * static_cast(full_height)); - auto start_col = round_to_nearest_even_usize(start_col_f * static_cast(full_width)); - auto height = round_to_nearest_even_usize(height_f * static_cast(full_height)); - auto width = round_to_nearest_even_usize(width_f * static_cast(full_width)); + KAI_ASSUME(_start_row >= 0.0F && _start_row <= 1.0F); + KAI_ASSUME(_start_col >= 0.0F && _start_col <= 1.0F); + KAI_ASSUME(_height >= 0.0F && _height <= 1.0F); + KAI_ASSUME(_width >= 0.0F && _width <= 1.0F); + + auto start_row = round_to_nearest_even_usize(_start_row * static_cast(full_height)); + auto start_col = round_to_nearest_even_usize(_start_col * static_cast(full_width)); + auto height = round_to_nearest_even_usize(_height * static_cast(full_height)); + auto width = round_to_nearest_even_usize(_width * static_cast(full_width)); start_row = round_down_multiple(start_row, scheduler_block_height); start_col = round_down_multiple(start_col, scheduler_block_width); diff --git a/test/common/matrix_portion.hpp b/test/common/matrix_portion.hpp index 02f9f548..68193a40 100644 --- a/test/common/matrix_portion.hpp +++ b/test/common/matrix_portion.hpp @@ -14,7 +14,7 @@ namespace kai::test { /// Portion of a matrix. /// -/// This class is used to define the sub-matrix a test is running and checking. +/// This class is used to define the sub-matrix under test. /// /// This is the relative version of @ref Rect. class MatrixPortion { @@ -43,10 +43,10 @@ public: /// /// Requirements: /// - /// * The starting coordinate of the sub-matrix shall be aligned with the block boundary. 
- /// * If it is not the block at the right and/or bottom edge of the full matrix, the height and width - /// of the sub-matrix shall be rounded up to multiple of the block height and width. - /// * If it is the block at the right and/or bottom edge of the full matrix, the height and width + /// * The starting coordinate of the sub-matrix shall be aligned with the scheduling block boundary. + /// * If it is not the scheduling block at the right and/or bottom edge of the full matrix, the height and width + /// of the sub-matrix shall be rounded up to multiple of the scheduling block height and width. + /// * If it is the scheduling block at the right and/or bottom edge of the full matrix, the height and width /// of the sub-matrix shall be the rounded up to the edge of the matrix. /// /// @param[in] full_height Matrix height. -- GitLab