From 133a04448a87e74ca22d0101a3cd947fef7d918c Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Tue, 30 Apr 2024 12:42:29 +0100 Subject: [PATCH 1/3] Create test framework * Add common concepts: - Data type. - Data format. * Add custom data types: - 4-bit signed integer. - 4-bit unsigned integer. * Add generic type utilities: - Type traits. - Numeric limits. - Memory access. * Add matrix printer for arbitrary data format. * Add reference implementation of rounding functions. Signed-off-by: Viet-Hoa Do --- .clang-tidy | 5 ++ CMakeLists.txt | 34 ++++++-- src/kai_common.h | 42 ++++++++++ test/common/data_format.cpp | 144 +++++++++++++++++++++++++++++++++ test/common/data_format.hpp | 126 +++++++++++++++++++++++++++++ test/common/data_type.cpp | 81 +++++++++++++++++++ test/common/data_type.hpp | 110 +++++++++++++++++++++++++ test/common/int4.cpp | 98 ++++++++++++++++++++++ test/common/int4.hpp | 112 +++++++++++++++++++++++++ test/common/logging.hpp | 68 ++++++++++++++++ test/common/memory.hpp | 77 ++++++++++++++++++ test/common/numeric_limits.hpp | 39 +++++++++ test/common/printer.cpp | 112 +++++++++++++++++++++++++ test/common/printer.hpp | 28 +++++++ test/common/type_traits.hpp | 79 ++++++++++++++++++ test/reference/round.cpp | 36 +++++++++ test/reference/round.hpp | 62 ++++++++++++++ 17 files changed, 1246 insertions(+), 7 deletions(-) create mode 100644 src/kai_common.h create mode 100644 test/common/data_format.cpp create mode 100644 test/common/data_format.hpp create mode 100644 test/common/data_type.cpp create mode 100644 test/common/data_type.hpp create mode 100644 test/common/int4.cpp create mode 100644 test/common/int4.hpp create mode 100644 test/common/logging.hpp create mode 100644 test/common/memory.hpp create mode 100644 test/common/numeric_limits.hpp create mode 100644 test/common/printer.cpp create mode 100644 test/common/printer.hpp create mode 100644 test/common/type_traits.hpp create mode 100644 test/reference/round.cpp create mode 100644 test/reference/round.hpp diff --git a/.clang-tidy b/.clang-tidy index 7cff2942..bebfc13a 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -22,5 +22,10 @@ readability-*, -readability-identifier-length, -readability-magic-numbers, -readability-function-cognitive-complexity, +-cppcoreguidelines-pro-type-reinterpret-cast, +-cppcoreguidelines-avoid-magic-numbers, +-readability-simplify-boolean-expr, +-bugprone-easily-swappable-parameters, +-cppcoreguidelines-pro-bounds-pointer-arithmetic ' ... diff --git a/CMakeLists.txt b/CMakeLists.txt index d45d5070..b6d9662c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,9 +31,7 @@ endif() set(KLEIDIAI_WARNING_FLAGS "-Wall" - "-Wctor-dtor-privacy" "-Wdisabled-optimization" - "-Weffc++" "-Werror" "-Wextra" "-Wformat-security" @@ -42,19 +40,41 @@ set(KLEIDIAI_WARNING_FLAGS "-Wno-ignored-attributes" "-Wno-misleading-indentation" "-Wno-overlength-strings" - "-Woverloaded-virtual" - "-Wsign-promo" "-Wstrict-overflow=2" "-Wswitch-default" ) +set(KLEIDIAI_WARNING_FLAGS_CXX + "-Wctor-dtor-privacy" + "-Weffc++" + "-Woverloaded-virtual" + "-Wsign-promo" +) + if(KLEIDIAI_BUILD_TESTS) enable_testing() include(GoogleTest) - add_executable(kleidiai_test test/sample.cpp) - target_compile_options(kleidiai_test PRIVATE ${KLEIDIAI_WARNING_FLAGS}) - target_link_libraries(kleidiai_test PRIVATE GTest::gtest_main) + add_executable(kleidiai_test + test/common/data_type.cpp + test/common/data_format.cpp + test/common/printer.cpp + test/common/int4.cpp + test/reference/round.cpp + ) + + target_include_directories(kleidiai_test + PRIVATE . + ) + + target_compile_options(kleidiai_test + PRIVATE ${KLEIDIAI_WARNING_FLAGS} + PRIVATE ${KLEIDIAI_WARNING_FLAGS_CXX} + ) + + target_link_libraries(kleidiai_test + PRIVATE GTest::gtest_main + ) # Cross-compiling is a common use case which creates a conflict if DISCOVERY_MODE is set to POST_BUILD (by default) # since the host platform does not match the target. Setting the mode to PRE_TEST avoids this conflict. diff --git a/src/kai_common.h b/src/kai_common.h new file mode 100644 index 00000000..b44dad5b --- /dev/null +++ b/src/kai_common.h @@ -0,0 +1,42 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +// NOLINTBEGIN(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c) +// +// * cppcoreguidelines-avoid-do-while: do-while is necessary for macros. +// * cppcoreguidelines-pro-type-vararg: use of variadic arguments in fprintf is expected. +// * cert-err33-c: checking the output of fflush and fprintf is not necessary for error reporting. +#define KAI_ERROR(msg) \ + do { \ + fflush(stdout); \ + fprintf(stderr, "%s", msg); \ + exit(EXIT_FAILURE); \ + } while (0) + +#define KAI_ASSERT_MSG(cond, msg) \ + do { \ + if (!(cond)) { \ + KAI_ERROR(msg); \ + } \ + } while (0) +// NOLINTEND(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c) + +#define KAI_ASSERT(cond) KAI_ASSERT_MSG(cond, #cond) + +#define KAI_ASSERT_IF_MSG(precond, cond, msg) KAI_ASSERT_MSG(!(precond) || (cond), msg) +#define KAI_ASSERT_IF(precond, cond) KAI_ASSERT_IF_MSG(precond, cond, #precond " |-> " #cond) + +#define KAI_ASSUME_MSG KAI_ASSERT_MSG +#define KAI_ASSUME KAI_ASSERT +#define KAI_ASSUME_IF_MSG KAI_ASSERT_IF_MSG +#define KAI_ASSUME_IF KAI_ASSERT_IF + +#define KAI_UNUSED(x) (void)(x) diff --git a/test/common/data_format.cpp b/test/common/data_format.cpp new file mode 100644 index 00000000..1673b405 --- /dev/null +++ b/test/common/data_format.cpp @@ -0,0 +1,144 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/data_format.hpp" + +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_type.hpp" +#include "test/reference/round.hpp" + +namespace kai::test { + +DataFormat::DataFormat( + DataType data_type, size_t block_height, size_t block_width, QuantizationFormat quant_format, DataType scale_dt, + DataType zero_point_dt, size_t subblock_height, size_t subblock_width) noexcept : + _data_type(data_type), + _quant_format(quant_format), + _scale_dt(scale_dt), + _zero_point_dt(zero_point_dt), + _block_height(block_height), + _block_width(block_width), + _subblock_height(subblock_height), + _subblock_width(subblock_width) { +} + +bool DataFormat::operator==(const DataFormat& rhs) const { + return _data_type == rhs._data_type && _quant_format == rhs._quant_format && _scale_dt == rhs._scale_dt && + _zero_point_dt == rhs._zero_point_dt && _block_height == rhs._block_height && _block_width == rhs._block_width; +} + +bool DataFormat::operator!=(const DataFormat& rhs) const { + return !(*this == rhs); +} + +DataType DataFormat::data_type() const { + return _data_type; +} + +DataFormat::QuantizationFormat DataFormat::quantization_format() const { + return _quant_format; +} + +DataType DataFormat::scale_data_type() const { + return _scale_dt; +} + +DataType DataFormat::zero_point_data_type() const { + return _zero_point_dt; +} + +bool DataFormat::is_raw() const { + return _quant_format == QuantizationFormat::NONE && // + _block_height == 0 && _block_width == 0 && _subblock_height == 0 && _subblock_width == 0; +} + +size_t DataFormat::block_height() const { + return _block_height; +} + +size_t DataFormat::block_width() const { + return _block_width; +} + +size_t DataFormat::subblock_height() const { + return _subblock_height; +} + +size_t DataFormat::subblock_width() const { + return _subblock_width; +} + +size_t DataFormat::scheduler_block_height([[maybe_unused]] size_t full_height) const { + switch (_quant_format) { + case QuantizationFormat::NONE: + return _block_height > 0 ? _block_height : 1; + + case QuantizationFormat::PER_ROW: + return _block_height; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +size_t DataFormat::scheduler_block_width(size_t full_width) const { + switch (_quant_format) { + case QuantizationFormat::NONE: + return _block_width > 0 ? _block_width : 1; + + case QuantizationFormat::PER_ROW: + return full_width; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +uintptr_t DataFormat::default_row_stride(size_t width) const { + const auto padded_width = round_up_multiple(width, _block_width > 0 ? _block_width : 1); + + switch (_quant_format) { + case QuantizationFormat::NONE: + return padded_width * data_type_size_in_bits(_data_type) / 8; + + case QuantizationFormat::PER_ROW: + return _block_height * data_type_size_in_bits(_zero_point_dt) / 8 + // + _block_height * padded_width * data_type_size_in_bits(_data_type) / 8 + // + _block_height * data_type_size_in_bits(_scale_dt) / 8; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +uintptr_t DataFormat::default_offset_in_bytes(size_t row, size_t col, size_t width) const { + const auto row_stride = default_row_stride(width); + + KAI_ASSERT(col % scheduler_block_width(width) == 0); + + switch (_quant_format) { + case QuantizationFormat::NONE: + return row * row_stride + col * data_type_size_in_bits(_data_type) / 8; + + case QuantizationFormat::PER_ROW: + KAI_ASSERT(row % _block_height == 0); + KAI_ASSERT(col == 0); + return (row / _block_height) * row_stride + col * data_type_size_in_bits(_data_type) / 8; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +size_t DataFormat::default_size_in_bytes(size_t height, size_t width) const { + const auto num_rows = _block_height > 0 ? (height + _block_height - 1) / _block_height : height; + return num_rows * default_row_stride(width); +} + +} // namespace kai::test diff --git a/test/common/data_format.hpp b/test/common/data_format.hpp new file mode 100644 index 00000000..3aa19f71 --- /dev/null +++ b/test/common/data_format.hpp @@ -0,0 +1,126 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "test/common/data_type.hpp" + +namespace kai::test { + +/// Data format. +class DataFormat { +public: + /// Quantization packing format. + enum class QuantizationFormat : uint32_t { + NONE, ///< No quantization information is included. + PER_ROW, ///< Per-row quantization. + }; + + /// Creates a new data format. + /// + /// @param[in] data_type Data type of data value. + /// @param[in] block_height Block height. + /// @param[in] block_width Block width. + /// @param[in] quant_format Quantization packing format. + /// @param[in] scale_dt Data type of scale value. + /// @param[in] zero_point_dt Data type of zero point value. + /// @param[in] subblock_height Sub-block height. + /// @param[in] subblock_width Sub-block width. + DataFormat( + DataType data_type, size_t block_height = 0, size_t block_width = 0, + QuantizationFormat quant_format = QuantizationFormat::NONE, DataType scale_dt = DataType::UNKNOWN, + DataType zero_point_dt = DataType::UNKNOWN, size_t subblock_height = 0, size_t subblock_width = 0) noexcept; + + /// Equality operator. + [[nodiscard]] bool operator==(const DataFormat& rhs) const; + + /// Unequality operator. + [[nodiscard]] bool operator!=(const DataFormat& rhs) const; + + /// Gets the quantization packing format. + [[nodiscard]] QuantizationFormat quantization_format() const; + + /// Gets the data type of data value. + [[nodiscard]] DataType data_type() const; + + /// Gets the data type of scale value. + [[nodiscard]] DataType scale_data_type() const; + + /// Gets the data type of zero point value. + [[nodiscard]] DataType zero_point_data_type() const; + + /// Gets a value indicating whether this format has no blocking or packed quantization information. + [[nodiscard]] bool is_raw() const; + + /// Gets the block height. + [[nodiscard]] size_t block_height() const; + + /// Gets the block width. + [[nodiscard]] size_t block_width() const; + + /// Gets the sub-block height. + [[nodiscard]] size_t subblock_height() const; + + /// Gets the sub-block width. + [[nodiscard]] size_t subblock_width() const; + + /// Gets the scheduling block height. + /// + /// @param[in] full_height Height of the full matrix. + /// + /// @return The block height for scheduling purpose. + [[nodiscard]] size_t scheduler_block_height(size_t full_height) const; + + /// Gets the scheduling block width. + /// + /// @param[in] full_width Width of the full matrix. + /// + /// @return The block width for scheduling purpose. + [[nodiscard]] size_t scheduler_block_width(size_t full_width) const; + + /// Gets the row stride in bytes given the data is stored continuously without any gap in the memory. + /// + /// In case of per-row quantization, the row stride is the number of bytes from one row group + /// to the next. One row group consists of `block_height` rows. + /// + /// @param[in] width Width of the full matrix. + /// + /// @return The default row stride in bytes of the matrix. + [[nodiscard]] uintptr_t default_row_stride(size_t width) const; + + /// Gets the offsets in bytes in the data buffer given the data is stored continuously + /// without any gap in the memory. + /// + /// @param[in] row Row coordinate. + /// @param[in] col Colum coordinate. + /// @param[in] width Width of the full matrix. + /// + /// @return The default offset in bytes. + [[nodiscard]] uintptr_t default_offset_in_bytes(size_t row, size_t col, size_t width) const; + + /// Gets the size in bytes of the matrix given the data is stored continuously without any gap in the memory. + /// + /// @param[in] height Height of the full matrix. + /// @param[in] width Width of the full matrix. + /// + /// @return The size in bytes of the matrix. + [[nodiscard]] size_t default_size_in_bytes(size_t height, size_t width) const; + +private: + DataType _data_type; + QuantizationFormat _quant_format; + DataType _scale_dt; + DataType _zero_point_dt; + size_t _block_height; + size_t _block_width; + size_t _subblock_height; + size_t _subblock_width; +}; + +} // namespace kai::test diff --git a/test/common/data_type.cpp b/test/common/data_type.cpp new file mode 100644 index 00000000..79cc7e7c --- /dev/null +++ b/test/common/data_type.cpp @@ -0,0 +1,81 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/data_type.hpp" + +#include +#include + +#include "src/kai_common.h" + +namespace kai::test { + +namespace { + +bool has_i(DataType dt) { + return (static_cast(dt) & (1 << 15)) != 0; +} + +bool has_s(DataType dt) { + return (static_cast(dt) & (1 << 14)) != 0; +} + +bool has_q(DataType dt) { + return (static_cast(dt) & (1 << 13)) != 0; +} + +bool has_a(DataType dt) { + return (static_cast(dt) & (1 << 12)) != 0; +} + +size_t bits(DataType dt) { + return static_cast(dt) & 0xFF; +} + +} // namespace + +size_t data_type_size_in_bits(DataType dt) { + return bits(dt); +} + +bool data_type_is_integral(DataType dt) { + return has_i(dt); +} + +bool data_type_is_float(DataType dt) { + KAI_ASSERT(data_type_is_signed(dt)); + return !data_type_is_integral(dt); +} + +bool data_type_is_float_fp(DataType dt) { + KAI_ASSERT(data_type_is_float(dt)); + return !has_q(dt); +} + +bool data_type_is_float_bf(DataType dt) { + KAI_ASSERT(data_type_is_float(dt)); + return has_q(dt); +} + +bool data_type_is_signed(DataType dt) { + if (!has_s(dt)) { + KAI_ASSERT(data_type_is_integral(dt)); + } + + return has_s(dt); +} + +bool data_type_is_quantized(DataType dt) { + KAI_ASSERT_IF(has_q(dt), data_type_is_integral(dt)); + return has_q(dt); +} + +bool data_type_is_quantized_asymm(DataType dt) { + KAI_ASSERT_IF(has_a(dt), data_type_is_quantized(dt)); + return has_a(dt); +} + +} // namespace kai::test diff --git a/test/common/data_type.hpp b/test/common/data_type.hpp new file mode 100644 index 00000000..8fbdcf9d --- /dev/null +++ b/test/common/data_type.hpp @@ -0,0 +1,110 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace kai::test { + +/// Data type. +enum class DataType : uint16_t { + // Encoding: + // + // 15 0 + // +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ + // | i | s | q | a | RES0 | bits | + // +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ + // + // (RES0: reserved, filled with 0s) + // + // Fields: + // + // * i: integer (1) or floating-point (0). + // * s: signed (1) or unsigned (0). + // * q: + // - Integer (i): quantized (1) or non-quantized (0). + // - Floating-point (!i): brain (1) or binary (0). + // * a: + // - Quantized (i && q): asymmetric (1) or symmetric (0). + // - Otherwise: RES0. + // * bits: size in bits. + + UNKNOWN = 0, ///< No data. + + FP32 = 0b0'1'0'0'0000'00100000, ///< Single-precision floating-point. + FP16 = 0b0'1'0'0'0000'00010000, ///< Half-precision floating-point. + + I32 = 0b1'1'0'0'0000'00100000, ///< 32-bit signed integer. + + QI8 = 0b1'1'1'1'0000'00001000, ///< 8-bit asymmetric quantized. + + QSU4 = 0b1'0'1'0'0000'00000100, ///< 4-bit unsigned symmetric quantized. + QSI4 = 0b1'1'1'0'0000'00000100, ///< 4-bit signed symmetric quantized. +}; + +/// Gets the size in bits of the specified data type. +/// +/// @param[in] dt The data type. +/// +/// @return The size in bits. +[[nodiscard]] size_t data_type_size_in_bits(DataType dt); + +/// Gets a value indicating whether the data type is integral. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is integral. +[[nodiscard]] bool data_type_is_integral(DataType dt); + +/// Gets a value indicating whether the data type is floating-point. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is floating-point. +[[nodiscard]] bool data_type_is_float(DataType dt); + +/// Gets a value indicating whether the data type is binary floating-point. +/// +/// Binary floating point are `half`, `float`, `double`. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is binary floating-point. +[[nodiscard]] bool data_type_is_float_fp(DataType dt); + +/// Gets a value indicating whether the data type is brain floating-point. +/// +/// Binary floating point are `bfloat16`. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is brain floating-point. +[[nodiscard]] bool data_type_is_float_bf(DataType dt); + +/// Gets a value indicating whether the data type is signed. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is signed. +[[nodiscard]] bool data_type_is_signed(DataType dt); + +/// Gets a value indicating whether the data type is quantized. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is quantized. +[[nodiscard]] bool data_type_is_quantized(DataType dt); + +/// Gets a value indicating whether the data type is asymmetric quantized. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is asymmetric quantized. +[[nodiscard]] bool data_type_is_quantized_asymm(DataType dt); + +} // namespace kai::test diff --git a/test/common/int4.cpp b/test/common/int4.cpp new file mode 100644 index 00000000..bc9563b1 --- /dev/null +++ b/test/common/int4.cpp @@ -0,0 +1,98 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/int4.hpp" + +#include +#include + +namespace kai::test { + +UInt4& UInt4::operator=(uint8_t value) { + _value = value; + return *this; +} + +UInt4::operator int32_t() const { + return _value; +} + +UInt4::operator float() const { + return _value; +} + +UInt4 UInt4::operator+(UInt4 rhs) const { + return UInt4(_value + rhs._value); +} + +UInt4 UInt4::operator-(UInt4 rhs) const { + return UInt4(_value - rhs._value); +} + +UInt4 UInt4::operator*(UInt4 rhs) const { + return UInt4(_value * rhs._value); +} + +UInt4 UInt4::operator/(UInt4 rhs) const { + return UInt4(_value / rhs._value); +} + +uint8_t UInt4::pack_u8(UInt4 low, UInt4 high) { + return (low._value & 0x0F) | (high._value << 4); +} + +std::tuple UInt4::unpack_u8(uint8_t value) { + const uint8_t low = value & 0x0F; + const uint8_t high = value >> 4; + + return {UInt4(low), UInt4(high)}; +} + +// ===================================================================================================================== + +Int4& Int4::operator=(int8_t value) { + _value = value; + return *this; +} + +Int4::operator int32_t() const { + return _value; +} + +Int4::operator float() const { + return _value; +} + +Int4 Int4::operator+(Int4 rhs) const { + return Int4(static_cast(_value + rhs._value)); +} + +Int4 Int4::operator-(Int4 rhs) const { + return Int4(static_cast(_value - rhs._value)); +} + +Int4 Int4::operator*(Int4 rhs) const { + return Int4(static_cast(_value * rhs._value)); +} + +Int4 Int4::operator/(Int4 rhs) const { + return Int4(static_cast(_value / rhs._value)); +} + +uint8_t Int4::pack_u8(Int4 low, Int4 high) { + return (low._value & 0x0F) | (high._value << 4); +} + +std::tuple Int4::unpack_u8(uint8_t value) { + // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + const int8_t low = static_cast(value << 4) >> 4; + const int8_t high = static_cast(value) >> 4; + // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + + return {Int4(low), Int4(high)}; +} + +} // namespace kai::test diff --git a/test/common/int4.hpp b/test/common/int4.hpp new file mode 100644 index 00000000..169c5817 --- /dev/null +++ b/test/common/int4.hpp @@ -0,0 +1,112 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace kai::test { + +/// 4-bit unsigned integer. +class UInt4 { +public: + /// Creates a new 4-bit unsigned integer value. + /// + /// @param[in] value Value. + constexpr explicit UInt4(uint8_t value) : _value(value) { + } + + /// Assignment operator. + UInt4& operator=(uint8_t value); + + /// Conversion operator. + operator int32_t() const; + + /// Conversion operator. + operator float() const; + + /// Addition operator. + [[nodiscard]] UInt4 operator+(UInt4 rhs) const; + + /// Subtraction operator. + [[nodiscard]] UInt4 operator-(UInt4 rhs) const; + + /// Multiplication operator. + [[nodiscard]] UInt4 operator*(UInt4 rhs) const; + + /// Division operator. + [[nodiscard]] UInt4 operator/(UInt4 rhs) const; + + /// Packs two 4-bit unsigned integer values into one byte. + /// + /// @param[in] low Low nibble. + /// @param[in] high High nibble. + /// + /// @return The packed byte. + [[nodiscard]] static uint8_t pack_u8(UInt4 low, UInt4 high); + + /// Unpacks one byte to two 4-bit unsigned integer values. + /// + /// @param[in] value 8-bit packed value. + /// + /// @return The low and high nibbles. + [[nodiscard]] static std::tuple unpack_u8(uint8_t value); + +private: + uint8_t _value; +}; + +/// 4-bit signed integer. +class Int4 { +public: + /// Creates a new 4-bit signed integer value. + /// + /// @param[in] value Value. + constexpr explicit Int4(int8_t value) : _value(value) { + } + + /// Assignment operator. + Int4& operator=(int8_t value); + + /// Conversion operator. + explicit operator int32_t() const; + + /// Conversion operator. + explicit operator float() const; + + /// Addition operator. + [[nodiscard]] Int4 operator+(Int4 rhs) const; + + /// Subtraction operator. + [[nodiscard]] Int4 operator-(Int4 rhs) const; + + /// Multiplication operator. + [[nodiscard]] Int4 operator*(Int4 rhs) const; + + /// Division operator. + [[nodiscard]] Int4 operator/(Int4 rhs) const; + + /// Packs two 4-bit signed integer values into one byte. + /// + /// @param[in] low Low nibble. + /// @param[in] high High nibble. + /// + /// @return The packed byte. + [[nodiscard]] static uint8_t pack_u8(Int4 low, Int4 high); + + /// Unpacks one byte to two 4-bit signed integer values. + /// + /// @param[in] value 8-bit packed value. + /// + /// @return The low and high nibbles. + [[nodiscard]] static std::tuple unpack_u8(uint8_t value); + +private: + int8_t _value; +}; + +} // namespace kai::test diff --git a/test/common/logging.hpp b/test/common/logging.hpp new file mode 100644 index 00000000..ff6130a0 --- /dev/null +++ b/test/common/logging.hpp @@ -0,0 +1,68 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "test/common/int4.hpp" + +#define KAI_LOGE(...) kai::test::detail::log("ERROR", __VA_ARGS__) + +namespace kai::test::detail { + +/// Prints the specified value to standard error. +/// +/// @tparam T Data type. +/// +/// @param[in] value Value to be printed out. +template +void write_log_content(T&& value) { + using TT = std::decay_t; + + if constexpr (std::is_same_v) { + std::cerr << static_cast(value); + } else if constexpr (std::is_same_v) { + std::cerr << static_cast(value); + } else if constexpr (std::is_same_v) { + std::cerr << static_cast(value); + } else if constexpr (std::is_same_v) { + std::cerr << static_cast(value); + } else { + std::cerr << value; + } +} + +/// Prints the specified values to standard error. +/// +/// @tparam T Data type of the first value. +/// @tparam Ts Data types of the subsequent values. +/// +/// @param[in] value First value to be printed out. +/// @param[in] others Subsequent values to be printed out. +template +void write_log_content(T&& value, Ts&&... others) { + write_log_content(std::forward(value)); + write_log_content(std::forward(others)...); +} + +/// Prints the log to standard error. +/// +/// @tparam Ts Data types of values to be printed out. +/// +/// @param[in] level Severity level. +/// @param[in] args Values to be printed out. +template +void log(std::string_view level, Ts&&... args) { + std::cerr << level << " | "; + write_log_content(std::forward(args)...); + std::cerr << "\n"; +} + +} // namespace kai::test::detail diff --git a/test/common/memory.hpp b/test/common/memory.hpp new file mode 100644 index 00000000..bf5fbb01 --- /dev/null +++ b/test/common/memory.hpp @@ -0,0 +1,77 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "test/common/int4.hpp" + +namespace kai::test { + +/// The size in bits of type `T`. +template +inline constexpr size_t size_in_bits = sizeof(T) * 8; + +/// The size in bits of type `T`. +template <> +inline constexpr size_t size_in_bits = 4; + +/// The size in bits of type `T`. +template <> +inline constexpr size_t size_in_bits = 4; + +/// Reads the array at the specified index. +/// +/// @param[in] array Data buffer. +/// @param[in] index Array index. +/// +/// @return The array value at the specified index. +template +T read_array(const void* array, size_t index) { + if constexpr (std::is_same_v) { + const auto [lo, hi] = UInt4::unpack_u8(reinterpret_cast(array)[index / 2]); + return index % 2 == 0 ? lo : hi; + } else if constexpr (std::is_same_v) { + const auto [lo, hi] = Int4::unpack_u8(reinterpret_cast(array)[index / 2]); + return index % 2 == 0 ? lo : hi; + } else { + return reinterpret_cast(array)[index]; + } +} + +/// Writes the specified value to the array. +/// +/// @param[in] array Data buffer. +/// @param[in] index Array index. +/// @param[in] value Value to be stored. +template +void write_array(void* array, size_t index, T value) { + if constexpr (std::is_same_v) { + auto* arr_value = reinterpret_cast(array) + index / 2; + const auto [lo, hi] = UInt4::unpack_u8(*arr_value); + + if (index % 2 == 0) { + *arr_value = UInt4::pack_u8(value, hi); + } else { + *arr_value = UInt4::pack_u8(lo, value); + } + } else if constexpr (std::is_same_v) { + auto* arr_value = reinterpret_cast(array) + index / 2; + const auto [lo, hi] = Int4::unpack_u8(*arr_value); + + if (index % 2 == 0) { + *arr_value = Int4::pack_u8(value, hi); + } else { + *arr_value = Int4::pack_u8(lo, value); + } + } else { + reinterpret_cast(array)[index] = value; + } +} + +} // namespace kai::test diff --git a/test/common/numeric_limits.hpp b/test/common/numeric_limits.hpp new file mode 100644 index 00000000..e5950810 --- /dev/null +++ b/test/common/numeric_limits.hpp @@ -0,0 +1,39 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "test/common/int4.hpp" + +namespace kai::test { + +/// Highest finite value of type `T`. +template +inline constexpr T numeric_highest = std::numeric_limits::max(); + +/// Highest finite value of type `T`. +template <> +inline constexpr UInt4 numeric_highest{15}; + +/// Highest finite value of type `T`. +template <> +inline constexpr Int4 numeric_highest{7}; + +/// Lowest finite value of type `T`. +template +inline constexpr T numeric_lowest = std::numeric_limits::lowest(); + +/// Lowest finite value of type `T`. +template <> +inline constexpr UInt4 numeric_lowest{0}; + +/// Lowest finite value of type `T`. +template <> +inline constexpr Int4 numeric_lowest{-8}; + +} // namespace kai::test diff --git a/test/common/printer.cpp b/test/common/printer.cpp new file mode 100644 index 00000000..368ae7db --- /dev/null +++ b/test/common/printer.cpp @@ -0,0 +1,112 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/common/int4.hpp" + +namespace kai::test { + +namespace { + +inline void print_data(std::ostream& os, const uint8_t* data, size_t len, DataType data_type) { + if (data_type == DataType::QSU4) { + for (size_t i = 0; i < len / 2; ++i) { + const auto [low, high] = UInt4::unpack_u8(data[i]); + os << static_cast(low) << ", " << static_cast(high) << ", "; + } + } else if (data_type == DataType::QSI4) { + for (size_t i = 0; i < len / 2; ++i) { + const auto [low, high] = Int4::unpack_u8(data[i]); + os << static_cast(low) << ", " << static_cast(high) << ", "; + } + } else { + for (size_t i = 0; i < len; ++i) { + switch (data_type) { + case DataType::FP32: + os << reinterpret_cast(data)[i]; + break; + + case DataType::I32: + os << reinterpret_cast(data)[i]; + break; + + case DataType::QI8: + os << static_cast(reinterpret_cast(data)[i]); + break; + + default: + KAI_ERROR("Unsupported data type!"); + } + + os << ", "; + } + } +} + +void print_matrix_raw(std::ostream& os, const uint8_t* data, DataType data_type, size_t height, size_t width) { + const auto row_stride = width * data_type_size_in_bits(data_type) / 8; + + os << "[\n"; + for (size_t y = 0; y < height; ++y) { + os << " ["; + print_data(os, data + y * row_stride, width, data_type); + os << "],\n"; + } + os << "]\n"; +} + +void print_matrix_per_row( + std::ostream& os, const uint8_t* data, const DataFormat& format, size_t height, size_t width) { + const auto block_height = format.block_height(); + const auto num_blocks = (height + block_height - 1) / block_height; + + const auto block_data_bytes = block_height * width * data_type_size_in_bits(format.data_type()) / 8; + const auto block_offsets_bytes = block_height * data_type_size_in_bits(format.zero_point_data_type()) / 8; + const auto block_scales_bytes = block_height * data_type_size_in_bits(format.scale_data_type()) / 8; + + os << "[\n"; + for (size_t y = 0; y < num_blocks; ++y) { + os << " {\"offsets\": ["; + print_data(os, data, block_height, format.zero_point_data_type()); + os << "], \"data\": ["; + print_data(os, data + block_offsets_bytes, block_height * width, format.data_type()); + os << "], \"scales\": ["; + print_data(os, data + block_offsets_bytes + block_data_bytes, block_height, format.scale_data_type()); + os << "]},\n"; + + data += block_offsets_bytes + block_data_bytes + block_scales_bytes; + } + os << "]\n"; +} + +} // namespace + +void print_matrix( + std::ostream& os, std::string_view name, const void* data, const DataFormat& format, size_t height, size_t width) { + os << name << " = "; + + switch (format.quantization_format()) { + case DataFormat::QuantizationFormat::NONE: + print_matrix_raw(os, reinterpret_cast(data), format.data_type(), height, width); + break; + + case DataFormat::QuantizationFormat::PER_ROW: + print_matrix_per_row(os, reinterpret_cast(data), format, height, width); + break; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +} // namespace kai::test diff --git a/test/common/printer.hpp b/test/common/printer.hpp new file mode 100644 index 00000000..ec6213ae --- /dev/null +++ b/test/common/printer.hpp @@ -0,0 +1,28 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace kai::test { + +class DataFormat; + +/// Prints the matrix data to the output stream. +/// +/// @param[in] os Output stream to write the data to. +/// @param[in] name Matrix name. +/// @param[in] data Data buffer. +/// @param[in] format Data format. +/// @param[in] height Number of rows. +/// @param[in] width Number of columns. +void print_matrix( + std::ostream& os, std::string_view name, const void* data, const DataFormat& format, size_t height, size_t width); + +} // namespace kai::test diff --git a/test/common/type_traits.hpp b/test/common/type_traits.hpp new file mode 100644 index 00000000..2267432a --- /dev/null +++ b/test/common/type_traits.hpp @@ -0,0 +1,79 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace kai::test { + +class UInt4; +class Int4; + +/// `true` if `T` is unsigned numeric type. +template +inline constexpr bool is_unsigned = std::is_unsigned_v; + +/// `true` if `T` is unsigned numeric type. +template <> +inline constexpr bool is_unsigned = true; + +/// `true` if `T` is unsigned numeric type. +template <> +inline constexpr bool is_unsigned = true; + +/// `true` if `T` is signed numeric type. +template +inline constexpr bool is_signed = std::is_signed_v; + +/// `true` if `T` is signed numeric type. +template <> +inline constexpr bool is_signed = false; + +/// `true` if `T` is signed numeric type. +template <> +inline constexpr bool is_signed = false; + +/// `true` if `T` is integral numeric type. +template +inline constexpr bool is_integral = std::is_integral_v; + +/// `true` if `T` is integral numeric type. +template <> +inline constexpr bool is_integral = true; + +/// `true` if `T` is integral numeric type. +template <> +inline constexpr bool is_integral = true; + +/// `true` if `T` is floating-point type. +template +inline constexpr bool is_floating_point = std::is_floating_point_v; + +/// Signed version of type `T`. +template +struct make_signed { + using type = std::make_signed_t; +}; + +/// Signed version of type `T`. +template <> +struct make_signed { + using type = Int4; +}; + +/// Signed version of type `T`. +template <> +struct make_signed { + using type = Int4; +}; + +/// Signed version of type `T`. +template +using make_signed_t = typename make_signed::type; + +} // namespace kai::test diff --git a/test/reference/round.cpp b/test/reference/round.cpp new file mode 100644 index 00000000..52a2c557 --- /dev/null +++ b/test/reference/round.cpp @@ -0,0 +1,36 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/round.hpp" + +#include +#include + +namespace kai::test { + +int32_t round_to_nearest_even_i32(float value) { + int32_t rounded = 0; + asm("fcvtns %w[output], %s[input]" : [output] "=r"(rounded) : [input] "w"(value)); + return rounded; +} + +size_t round_to_nearest_even_usize(float value) { + static_assert(sizeof(size_t) == sizeof(uint64_t)); + + uint64_t rounded = 0; + asm("fcvtns %x[output], %s[input]" : [output] "=r"(rounded) : [input] "w"(value)); + return rounded; +} + +size_t round_up_multiple(size_t a, size_t b) { + return ((a + b - 1) / b) * b; +} + +size_t round_down_multiple(size_t a, size_t b) { + return (a / b) * b; +} + +} // namespace kai::test diff --git a/test/reference/round.hpp b/test/reference/round.hpp new file mode 100644 index 00000000..7503936f --- /dev/null +++ b/test/reference/round.hpp @@ -0,0 +1,62 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace kai::test { + +/// Rounds the specified value to nearest with tie to even. +/// +/// For example: +/// +/// * 0.4 is rounded to 0. +/// * 0.5 is rounded to 0 (as 0 is the nearest even value). +/// * 0.6 is rounded to 1. +/// * 1.4 is rounded to 1. +/// * 1.5 is rounded to 2 (as 2 is the nearest even value). +/// * 1.6 is rounded to 2. +/// +/// @param[in] value Value to be rounded. +/// +/// @return The rounded value. +int32_t round_to_nearest_even_i32(float value); + +/// Rounds the specified value to nearest with tie to even. +/// +/// For example: +/// +/// * 0.4 is rounded to 0. +/// * 0.5 is rounded to 0 (as 0 is the nearest even value). +/// * 0.6 is rounded to 1. +/// * 1.4 is rounded to 1. +/// * 1.5 is rounded to 2 (as 2 is the nearest even value). +/// * 1.6 is rounded to 2. +/// +/// @param[in] value Value to be rounded. +/// +/// @return The rounded value. +size_t round_to_nearest_even_usize(float value); + +/// Rounds up the input value to the multiple of the unit value. +/// +/// @param[in] a Input value. +/// @param[in] b Unit value. +/// +/// @return The rounded value. +size_t round_up_multiple(size_t a, size_t b); + +/// Rounds down the input value to the multiple of the unit value. +/// +/// @param[in] a Input value. +/// @param[in] b Unit value. +/// +/// @return The rounded value. +size_t round_down_multiple(size_t a, size_t b); + +} // namespace kai::test -- GitLab From 5be25ff33fb098b03abc3ef391c98f12b24cadcc Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 2 May 2024 10:11:11 +0100 Subject: [PATCH 2/3] Change data type QI8 to QAI8 Signed-off-by: Viet-Hoa Do --- test/common/data_type.hpp | 2 +- test/common/printer.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/common/data_type.hpp b/test/common/data_type.hpp index 8fbdcf9d..c707b53f 100644 --- a/test/common/data_type.hpp +++ b/test/common/data_type.hpp @@ -41,7 +41,7 @@ enum class DataType : uint16_t { I32 = 0b1'1'0'0'0000'00100000, ///< 32-bit signed integer. - QI8 = 0b1'1'1'1'0000'00001000, ///< 8-bit asymmetric quantized. + QAI8 = 0b1'1'1'1'0000'00001000, ///< 8-bit unsigned asymmetric quantized. QSU4 = 0b1'0'1'0'0000'00000100, ///< 4-bit unsigned symmetric quantized. QSI4 = 0b1'1'1'0'0000'00000100, ///< 4-bit signed symmetric quantized. diff --git a/test/common/printer.cpp b/test/common/printer.cpp index 368ae7db..9bc7a100 100644 --- a/test/common/printer.cpp +++ b/test/common/printer.cpp @@ -40,7 +40,7 @@ inline void print_data(std::ostream& os, const uint8_t* data, size_t len, DataTy os << reinterpret_cast(data)[i]; break; - case DataType::QI8: + case DataType::QAI8: os << static_cast(reinterpret_cast(data)[i]); break; -- GitLab From 6d2a4a7e85d661f9e9807e3ef9bbe9bb716131a9 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Fri, 3 May 2024 09:29:53 +0100 Subject: [PATCH 3/3] Fix documentation Signed-off-by: Viet-Hoa Do --- test/common/data_type.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/common/data_type.hpp b/test/common/data_type.hpp index c707b53f..1144e4d5 100644 --- a/test/common/data_type.hpp +++ b/test/common/data_type.hpp @@ -41,7 +41,7 @@ enum class DataType : uint16_t { I32 = 0b1'1'0'0'0000'00100000, ///< 32-bit signed integer. - QAI8 = 0b1'1'1'1'0000'00001000, ///< 8-bit unsigned asymmetric quantized. + QAI8 = 0b1'1'1'1'0000'00001000, ///< 8-bit signed asymmetric quantized. QSU4 = 0b1'0'1'0'0000'00000100, ///< 4-bit unsigned symmetric quantized. QSI4 = 0b1'1'1'0'0000'00000100, ///< 4-bit signed symmetric quantized. -- GitLab