From b1cd4bee222eb3b80d095209ca66ddcbca85b8ec Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Tue, 16 Apr 2024 16:32:42 +0100 Subject: [PATCH 1/3] Create test project and add the first test * Create the test project consisting of: - Common concepts and utilities. - Reference operations. - Test cases. * Add the test for LHS 8-bit per-row quantized packing kernel: - This test shares the same test fixture with other tests for kernels involved in performing matrix multiplication. - The kernel is required to compute the whole output matrix or an arbitrary portion of it correctly. - The matmul shape and the output portion are test parameters that are defined in the list of test instances. Signed-off-by: Viet-Hoa Do --- .clang-format | 4 +- CMakeLists.txt | 36 ++- src/kai_common.h | 27 ++- src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.c | 6 +- src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.h | 2 +- test/common/compare.cpp | 227 ++++++++++++++++++ test/common/compare.hpp | 125 ++++++++++ test/common/data_format.cpp | 126 ++++++++++ test/common/data_format.hpp | 112 +++++++++ test/common/data_type.cpp | 81 +++++++ test/common/data_type.hpp | 107 +++++++++ test/common/logging.hpp | 33 +++ test/common/matrix_portion.cpp | 48 ++++ test/common/matrix_portion.hpp | 53 ++++ test/common/printer.cpp | 99 ++++++++ test/common/printer.hpp | 28 +++ test/reference/fill.cpp | 66 +++++ test/reference/fill.hpp | 27 +++ test/reference/pack.cpp | 114 +++++++++ test/reference/pack.hpp | 27 +++ test/reference/quantize.cpp | 48 ++++ test/reference/quantize.hpp | 33 +++ test/reference/round.cpp | 36 +++ test/reference/round.hpp | 62 +++++ test/sample.cpp | 10 - test/tests/matmul_test.cpp | 190 +++++++++++++++ 26 files changed, 1705 insertions(+), 22 deletions(-) create mode 100644 test/common/compare.cpp create mode 100644 test/common/compare.hpp create mode 100644 test/common/data_format.cpp create mode 100644 test/common/data_format.hpp create mode 100644 test/common/data_type.cpp create mode 
100644 test/common/data_type.hpp create mode 100644 test/common/logging.hpp create mode 100644 test/common/matrix_portion.cpp create mode 100644 test/common/matrix_portion.hpp create mode 100644 test/common/printer.cpp create mode 100644 test/common/printer.hpp create mode 100644 test/reference/fill.cpp create mode 100644 test/reference/fill.hpp create mode 100644 test/reference/pack.cpp create mode 100644 test/reference/pack.hpp create mode 100644 test/reference/quantize.cpp create mode 100644 test/reference/quantize.hpp create mode 100644 test/reference/round.cpp create mode 100644 test/reference/round.hpp delete mode 100644 test/sample.cpp create mode 100644 test/tests/matmul_test.cpp diff --git a/.clang-format b/.clang-format index 77266d27..ad913784 100644 --- a/.clang-format +++ b/.clang-format @@ -5,12 +5,14 @@ # --- Language: Cpp -BasedOnStyle: LLVM +BasedOnStyle: Google ColumnLimit: 120 AccessModifierOffset: -4 AlignAfterOpenBracket: AlwaysBreak +AlignOperands: DontAlign AllowShortFunctionsOnASingleLine: None +BreakConstructorInitializers: AfterColon DerivePointerAlignment: false IndentWidth: 4 PointerAlignment: Left diff --git a/CMakeLists.txt b/CMakeLists.txt index d45d5070..76e07cfc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -48,13 +48,45 @@ set(KLEIDIAI_WARNING_FLAGS "-Wswitch-default" ) +add_library(kleidiai + src/matmul/kai_lhs_quant_pack_qa8dxP1X8_f32.c + src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.c + src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.c + src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.c + src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.c +) + +target_compile_options(kleidiai + PRIVATE ${KLEIDIAI_WARNING_FLAGS} +) + if(KLEIDIAI_BUILD_TESTS) enable_testing() include(GoogleTest) - add_executable(kleidiai_test test/sample.cpp) + add_executable(kleidiai_test + test/common/data_type.cpp + test/common/data_format.cpp + 
test/common/printer.cpp + test/common/matrix_portion.cpp + test/common/compare.cpp + test/reference/fill.cpp + test/reference/pack.cpp + test/reference/quantize.cpp + test/reference/round.cpp + test/tests/matmul_test.cpp + ) + + target_include_directories(kleidiai_test + PRIVATE . + ) + target_compile_options(kleidiai_test PRIVATE ${KLEIDIAI_WARNING_FLAGS}) - target_link_libraries(kleidiai_test PRIVATE GTest::gtest_main) + + target_link_libraries(kleidiai_test + PRIVATE kleidiai + PRIVATE GTest::gtest_main + ) # Cross-compiling is a common use case which creates a conflict if DISCOVERY_MODE is set to POST_BUILD (by default) # since the host platform does not match the target. Setting the mode to PRE_TEST avoids this conflict. diff --git a/src/kai_common.h b/src/kai_common.h index bed271c9..47462b4e 100644 --- a/src/kai_common.h +++ b/src/kai_common.h @@ -9,15 +9,32 @@ extern "C" { #endif +#include #include -#define KAI_ASSERT(x) \ - do { \ - if (!(x)) { \ - exit(EXIT_FAILURE); \ - } \ +#define KAI_ERROR(msg) \ + do { \ + fprintf(stderr, "%s", msg); \ + exit(EXIT_FAILURE); \ } while (0) +#define KAI_ASSERT_MSG(cond, msg) \ + do { \ + if (!(cond)) { \ + KAI_ERROR(msg); \ + } \ + } while (0) + +#define KAI_ASSERT(cond) KAI_ASSERT_MSG(cond, #cond) + +#define KAI_ASSERT_IF_MSG(precond, cond, msg) KAI_ASSERT_MSG(!(precond) || (cond), msg) +#define KAI_ASSERT_IF(precond, cond) KAI_ASSERT_IF_MSG(precond, cond, #precond " |-> " #cond) + +#define KAI_ASSUME_MSG KAI_ASSERT_MSG +#define KAI_ASSUME KAI_ASSERT +#define KAI_ASSUME_IF_MSG KAI_ASSERT_IF_MSG +#define KAI_ASSUME_IF KAI_ASSERT_IF + #define KAI_UNUSED(x) (void)(x) #define KAI_MIN(a, b) (((a) < (b)) ? (a) : (b)) #define KAI_MAX(a, b) (((a) > (b)) ? 
(a) : (b)) diff --git a/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.c b/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.c index 89fc26d9..166000c0 100644 --- a/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.c +++ b/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.c @@ -5,13 +5,13 @@ // #include "kai_lhs_quant_pack_qa8dxP4X8_f32.h" -#include "../kai_common.h" - #include #include #include #include +#include "../kai_common.h" + static const size_t kai_kk0 = 8; static const size_t kai_km0 = 4; static const size_t kai_num_bytes_per_multiplier = sizeof(float); @@ -35,7 +35,7 @@ size_t kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP4X8_f32(size_t m, size_t k) } void kai_run_lhs_quant_pack_qa8dxP4X8_f32( - size_t m, size_t k, const float* restrict lhs, size_t lhs_stride, void* restrict lhs_p) { + size_t m, size_t k, const void* restrict lhs, size_t lhs_stride, void* restrict lhs_p) { KAI_ASSERT(k % kai_kk0 == 0); KAI_ASSERT(m <= 3 || (m % kai_km0 == 0)); diff --git a/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.h b/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.h index 11bb931e..6309a9d8 100644 --- a/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.h +++ b/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.h @@ -58,7 +58,7 @@ size_t kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP4X8_f32(size_t m, size_t k); * @param[in] lhs_stride Stride in bytes between two rows of LHS. * @param[in] lhs_p The quantized and packed LHS matrix. 
*/ -void kai_run_lhs_quant_pack_qa8dxP4X8_f32(size_t m, size_t k, const float* lhs, size_t lhs_stride, void* lhs_p); +void kai_run_lhs_quant_pack_qa8dxP4X8_f32(size_t m, size_t k, const void* lhs, size_t lhs_stride, void* lhs_p); #ifdef __cplusplus } diff --git a/test/common/compare.cpp b/test/common/compare.cpp new file mode 100644 index 00000000..e9054b51 --- /dev/null +++ b/test/common/compare.cpp @@ -0,0 +1,227 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/compare.hpp" + +#include +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/logging.hpp" +#include "test/common/matrix_portion.hpp" +#include "test/common/printer.hpp" + +namespace kai::test { + +namespace { + +/// Calculates the absolute and relative errors. +/// +/// @param[in] imp Value under test. +/// @param[in] ref Reference value. +/// +/// @return The absolute error and relative error. +template +std::tuple calculate_error(T imp, T ref) { + const float imp_f = imp; + const float ref_f = ref; + + const auto abs_error = std::abs(imp_f - ref_f); + const auto rel_error = ref_f != 0 ? abs_error / std::abs(ref_f) : 0.0f; + + return {abs_error, rel_error}; +} + +/// Compares matrices with per-row quantization. 
+template +bool compare_per_row( + const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, + const MatrixPortion& portion, MismatchHandler& handler) { + const auto block_height = format.block_height(); + const auto block_width = format.block_width(); + + const auto [rect_y, rect_x, rect_height, rect_width] = + portion.compute_portion(full_height, full_width, block_height, full_width); + KAI_ASSERT(format.scheduler_block_height(full_height) == block_height); + KAI_ASSERT(format.scheduler_block_width(full_width) == full_width); + KAI_ASSERT(rect_x == 0); + KAI_ASSERT(rect_width == full_width); + + const auto num_groups = (full_height + block_height - 1) / block_height; + const auto group_num_blocks = (full_width + block_width - 1) / block_width; + + const auto group_offsets_bytes = block_height * sizeof(Offset); + const auto group_scales_bytes = block_height * sizeof(Scale); + const auto block_data_bytes = block_height * block_width * sizeof(Data); + + const auto begin_group = rect_y / block_height; + const auto end_group = (rect_y + rect_height) / block_height; + + const auto* imp_ptr = reinterpret_cast(imp_data); + const auto* ref_ptr = reinterpret_cast(ref_data); + + for (size_t group_no = 0; group_no < num_groups; ++group_no) { + const auto in_roi = group_no >= begin_group && group_no < end_group; + + // Checks the quantization offsets. + for (size_t i = 0; i < block_height; ++i) { + const auto imp_offset = reinterpret_cast(imp_ptr)[i]; + const Offset ref_offset = in_roi ? reinterpret_cast(ref_ptr)[i] : 0; + const auto [abs_err, rel_err] = calculate_error(imp_offset, ref_offset); + + if (abs_err != 0 || rel_err != 0) { + handler.mark_as_failed(); + + const auto raw_row = group_no * block_height + i; + KAI_LOGE( + "Mismatched quantization offset ", raw_row, ": actual = ", imp_offset, ", expected: ", ref_offset); + } + } + + imp_ptr += group_offsets_bytes; + ref_ptr += group_offsets_bytes; + + // Checks the data. 
+ for (size_t block_no = 0; block_no < group_num_blocks; ++block_no) { + for (size_t y = 0; y < block_height; ++y) { + for (size_t x = 0; x < block_width; ++x) { + const auto imp_data = reinterpret_cast(imp_ptr)[y * block_width + x]; + const Data ref_data = in_roi ? reinterpret_cast(ref_ptr)[y * block_width + x] : 0; + const auto [abs_err, rel_err] = calculate_error(imp_data, ref_data); + + if (abs_err != 0 || rel_err != 0) { + const auto notifying = !in_roi || handler.handle_data(abs_err, rel_err); + + if (notifying) { + const auto raw_row = group_no * block_height + y; + const auto raw_col = block_no * block_width + x; + + KAI_LOGE( + "Mismatched data at (", raw_row, ", ", raw_col, "): actual = ", imp_data, + ", expected: ", ref_data); + } + } + } + } + + imp_ptr += block_data_bytes; + ref_ptr += block_data_bytes; + } + + // Checks the quantization scales. + for (size_t i = 0; i < block_height; ++i) { + const auto imp_scale = reinterpret_cast(imp_ptr)[i]; + const Scale ref_scale = in_roi ? 
reinterpret_cast(ref_ptr)[i] : 0; + const auto [abs_err, rel_err] = calculate_error(imp_scale, ref_scale); + + if (abs_err != 0 || rel_err != 0) { + handler.mark_as_failed(); + + const auto raw_row = group_no * block_height + i; + KAI_LOGE( + "Mismatched quantization scale ", raw_row, ": actual = ", imp_scale, ", expected: ", ref_scale); + } + } + + imp_ptr += group_scales_bytes; + ref_ptr += group_scales_bytes; + } + + return handler.success(rect_height * full_width); +} + +} // namespace + +bool compare( + const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, + const MatrixPortion& portion, MismatchHandler&& handler) { + const auto data_type = format.data_type(); + const auto scale_dt = format.scale_data_type(); + const auto offset_dt = format.offset_data_type(); + + switch (format.quantization_format()) { + case DataFormat::QuantizationFormat::PER_ROW: + if (data_type == DataType::QI8 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { + return compare_per_row( + imp_data, ref_data, format, full_height, full_width, portion, handler); + } + + break; + + default: + break; + } + + KAI_ERROR("Unsupported format!"); +} + +// ===================================================================================================================== + +DefaultMismatchHandler::DefaultMismatchHandler( + float abs_error_threshold, float rel_error_threshold, size_t abs_mismatched_threshold, + float rel_mismatched_threshold) : + _abs_error_threshold(abs_error_threshold), + _rel_error_threshold(rel_error_threshold), + _abs_mismatched_threshold(abs_mismatched_threshold), + _rel_mismatched_threshold(rel_mismatched_threshold), + _num_mismatches(0), + _failed(false) { +} + +DefaultMismatchHandler::DefaultMismatchHandler(const DefaultMismatchHandler& rhs) : + _abs_error_threshold(rhs._abs_error_threshold), + _rel_error_threshold(rhs._rel_error_threshold), + _abs_mismatched_threshold(rhs._abs_mismatched_threshold), + 
_rel_mismatched_threshold(rhs._rel_mismatched_threshold), + _num_mismatches(0), + _failed(false) { + // Cannot copy mismatch handler that is already in use. + KAI_ASSERT(rhs._num_mismatches == 0); + KAI_ASSERT(!rhs._failed); +} + +DefaultMismatchHandler& DefaultMismatchHandler::operator=(const DefaultMismatchHandler& rhs) { + if (this != &rhs) { + // Cannot copy mismatch handler that is already in use. + KAI_ASSERT(rhs._num_mismatches == 0); + KAI_ASSERT(!rhs._failed); + + _abs_error_threshold = rhs._abs_error_threshold; + _rel_error_threshold = rhs._rel_error_threshold; + _abs_mismatched_threshold = rhs._abs_mismatched_threshold; + _rel_mismatched_threshold = rhs._rel_mismatched_threshold; + } + + return *this; +} + +bool DefaultMismatchHandler::handle_data(float absolute_error, float relative_error) { + const auto mismatched = absolute_error > _abs_error_threshold && relative_error > _rel_error_threshold; + + if (mismatched) { + ++_num_mismatches; + } + + return mismatched; +} + +void DefaultMismatchHandler::mark_as_failed() { + _failed = true; +} + +bool DefaultMismatchHandler::success(size_t num_checks) const { + if (_failed) { + return false; + } + + float mismatched_rate = static_cast(_num_mismatches) / static_cast(num_checks); + return _num_mismatches <= _abs_mismatched_threshold && mismatched_rate <= _rel_mismatched_threshold; +} + +} // namespace kai::test diff --git a/test/common/compare.hpp b/test/common/compare.hpp new file mode 100644 index 00000000..04ac829e --- /dev/null +++ b/test/common/compare.hpp @@ -0,0 +1,125 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace kai::test { + +class DataFormat; +class MatrixPortion; +class MismatchHandler; + +/// Compares two matrices to check whether they are matched. +/// +/// @param[in] imp_data Data buffer of the actual implementation matrix. 
+/// @param[in] ref_data Data buffer of the reference implementation matrix. +/// @param[in] format Data format. +/// @param[in] full_height Height of the full matrix. +/// @param[in] full_width Width of the full matrix. +/// @param[in] portion Portion of the matrix to be calculated in the actual implementation matrix. +/// @param[in] handler Mismatch handler. +/// +/// @return `true` if the two matrices are considered matched. +bool compare( + const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, + const MatrixPortion& portion, MismatchHandler&& handler); + +/// Handles mismatches found by @ref compare function. +class MismatchHandler { +public: + /// Constructor. + MismatchHandler() = default; + + /// Destructor. + virtual ~MismatchHandler() = default; + + /// Copy constructor. + MismatchHandler(const MismatchHandler&) = default; + + /// Copy assignment. + MismatchHandler& operator=(const MismatchHandler&) = default; + + /// Move constructor. + MismatchHandler(MismatchHandler&&) noexcept = default; + + /// Move assignment. + MismatchHandler& operator=(MismatchHandler&&) noexcept = default; + + /// Handles new mismatch result. + /// + /// This method must be called even when no error is detected. + /// + /// @param[in] absolute_error Absolute error. + /// @param[in] relative_error Relative error. + /// + /// @return `true` if the mismatch is sufficiently large to be logged as real mismatch. + virtual bool handle_data(float absolute_error, float relative_error) = 0; + + /// Marks the result as failed. + /// + /// It is zero tolerance if the data point is considered impossible to have mismatch + /// regardless of implementation method. + /// These normally include data points outside of the portion of interest (these must be 0) + /// and data points belonging to quantization information. + virtual void mark_as_failed() = 0; + + /// Returns a value indicating whether the two matrices are considered matched. 
+ /// + /// @param[in] num_checks Total number of data points that have been checked. + /// + /// @return `true` if the two matrices are considered matched. + [[nodiscard]] virtual bool success(size_t num_checks) const = 0; +}; + +/// This mismatch handler considers two values being mismatched when both the relative error +/// and the absolute error exceed their respective thresholds. +/// +/// This mismatch handler considers two matrices being mismatched when the number of mismatches +/// exceeds both the relative and absolute thresholds. +class DefaultMismatchHandler final : public MismatchHandler { +public: + /// Creates a new mismatch handler. + /// + /// @param[in] abs_error_threshold Threshold for absolute error. + /// @param[in] rel_error_threshold Threshold for relative error. + /// @param[in] abs_mismatched_threshold Threshold for the number of mismatched data points. + /// @param[in] rel_mismatched_threshold Threshold for the ratio of mismatched data points. + DefaultMismatchHandler( + float abs_error_threshold, float rel_error_threshold, size_t abs_mismatched_threshold, + float rel_mismatched_threshold); + + /// Destructor. + ~DefaultMismatchHandler() = default; + + /// Copy constructor. + DefaultMismatchHandler(const DefaultMismatchHandler& rhs); + + /// Copy assignment. + DefaultMismatchHandler& operator=(const DefaultMismatchHandler& rhs); + + /// Move constructor. + DefaultMismatchHandler(DefaultMismatchHandler&& rhs) noexcept = default; + + /// Move assignment. 
+ DefaultMismatchHandler& operator=(DefaultMismatchHandler&& rhs) noexcept = default; + + bool handle_data(float absolute_error, float relative_error) override; + void mark_as_failed() override; + [[nodiscard]] bool success(size_t num_checks) const override; + +private: + float _abs_error_threshold; + float _rel_error_threshold; + size_t _abs_mismatched_threshold; + float _rel_mismatched_threshold; + + size_t _num_mismatches; + bool _failed; +}; + +} // namespace kai::test diff --git a/test/common/data_format.cpp b/test/common/data_format.cpp new file mode 100644 index 00000000..35fddaaa --- /dev/null +++ b/test/common/data_format.cpp @@ -0,0 +1,126 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/data_format.hpp" + +#include + +#include "src/kai_common.h" +#include "test/common/data_type.hpp" +#include "test/reference/round.hpp" + +namespace kai::test { + +DataFormat::DataFormat( + DataType data_type, size_t block_height, size_t block_width, QuantizationFormat quant_format, DataType scale_dt, + DataType offset_dt) noexcept : + _data_type(data_type), + _quant_format(quant_format), + _scale_dt(scale_dt), + _offset_dt(offset_dt), + _block_height(block_height), + _block_width(block_width) { +} + +bool DataFormat::operator==(const DataFormat& rhs) const { + return _data_type == rhs._data_type && _quant_format == rhs._quant_format && _scale_dt == rhs._scale_dt && + _offset_dt == rhs._offset_dt && _block_height == rhs._block_height && _block_width == rhs._block_width; +} + +bool DataFormat::operator!=(const DataFormat& rhs) const { + return !(*this == rhs); +} + +DataType DataFormat::data_type() const { + return _data_type; +} + +DataFormat::QuantizationFormat DataFormat::quantization_format() const { + return _quant_format; +} + +DataType DataFormat::scale_data_type() const { + return _scale_dt; +} + +DataType DataFormat::offset_data_type() const { + return 
_offset_dt; +} + +size_t DataFormat::block_height() const { + return _block_height; +} + +size_t DataFormat::block_width() const { + return _block_width; +} + +size_t DataFormat::scheduler_block_height([[maybe_unused]] size_t full_height) const { + switch (_quant_format) { + case QuantizationFormat::NONE: + case QuantizationFormat::PER_ROW: + return _block_height; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +size_t DataFormat::scheduler_block_width(size_t full_width) const { + switch (_quant_format) { + case QuantizationFormat::NONE: + return _block_width; + + case QuantizationFormat::PER_ROW: + return full_width; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +uintptr_t DataFormat::default_row_stride(size_t width) const { + const auto padded_width = round_up_multiple(width, _block_width); + + switch (_quant_format) { + case QuantizationFormat::NONE: + return padded_width * data_type_size_in_bits(_data_type) / 8; + + case QuantizationFormat::PER_ROW: + return _block_height * data_type_size_in_bits(_offset_dt) / 8 + // + _block_height * padded_width * data_type_size_in_bits(_data_type) / 8 + // + _block_height * data_type_size_in_bits(_scale_dt) / 8; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +uintptr_t DataFormat::default_offset_in_bytes(size_t row, size_t col, size_t width) const { + const auto row_stride = default_row_stride(width); + + KAI_ASSERT(col % scheduler_block_width(width) == 0); + + switch (_quant_format) { + case QuantizationFormat::NONE: + return row * row_stride + col; + + case QuantizationFormat::PER_ROW: + KAI_ASSERT(row % _block_height == 0); + KAI_ASSERT(col == 0); + return (row / _block_height) * row_stride + col; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +size_t DataFormat::default_size_in_bytes(size_t height, size_t width) const { + const auto num_rows = (height + _block_height - 1) / _block_height; + 
return num_rows * default_row_stride(width); +} + +} // namespace kai::test diff --git a/test/common/data_format.hpp b/test/common/data_format.hpp new file mode 100644 index 00000000..bd9cfb7e --- /dev/null +++ b/test/common/data_format.hpp @@ -0,0 +1,112 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "test/common/data_type.hpp" + +namespace kai::test { + +class DataFormat { +public: + /// Quantization packing format. + enum class QuantizationFormat : uint32_t { + NONE, ///< No quantization information is included. + PER_ROW, ///< Per-row quantization. + }; + + /// Creates a new data format. + /// + /// @param[in] data_type Data type of data value. + /// @param[in] quant_format Quantization packing format. + /// @param[in] scale_dt Data type of scale value. + /// @param[in] offset_dt Data type of offset value. + /// @param[in] block_height Block height. + /// @param[in] block_width Block width. + DataFormat( + DataType data_type, size_t block_height = 1, size_t block_width = 1, + QuantizationFormat quant_format = QuantizationFormat::NONE, DataType scale_dt = DataType::UNKNOWN, + DataType offset_dt = DataType::UNKNOWN) noexcept; + + /// Equality operator. + [[nodiscard]] bool operator==(const DataFormat& rhs) const; + + /// Inequality operator. + [[nodiscard]] bool operator!=(const DataFormat& rhs) const; + + /// Gets the quantization packing format. + [[nodiscard]] QuantizationFormat quantization_format() const; + + /// Gets the data type of data value. + [[nodiscard]] DataType data_type() const; + + /// Gets the data type of scale value. + [[nodiscard]] DataType scale_data_type() const; + + /// Gets the data type of offset value. + [[nodiscard]] DataType offset_data_type() const; + + /// Gets the block height. + [[nodiscard]] size_t block_height() const; + + /// Gets the block width. 
+ [[nodiscard]] size_t block_width() const; + + /// Gets the scheduling block height. + /// + /// @param[in] full_height Height of the full matrix. + /// + /// @return The block height for scheduling purpose. + [[nodiscard]] size_t scheduler_block_height(size_t full_height) const; + + /// Gets the scheduling block width. + /// + /// @param[in] full_width Width of the full matrix. + /// + /// @return The block width for scheduling purpose. + [[nodiscard]] size_t scheduler_block_width(size_t full_width) const; + + /// Gets the row stride in bytes given the data is stored contiguously without any gap in the memory. + /// + /// In case of per-row quantization, the row stride is the number of bytes from one row group + /// to the next. One row group consists of `block_height` rows. + /// + /// @param[in] width Width of the full matrix. + /// + /// @return The default row stride in bytes of the matrix. + [[nodiscard]] uintptr_t default_row_stride(size_t width) const; + + /// Gets the offset in bytes in the data buffer given the data is stored contiguously + /// without any gap in the memory. + /// + /// @param[in] row Row coordinate. + /// @param[in] col Column coordinate. + /// @param[in] width Width of the full matrix. + /// + /// @return The default offset in bytes. + [[nodiscard]] uintptr_t default_offset_in_bytes(size_t row, size_t col, size_t width) const; + + /// Gets the size in bytes of the matrix given the data is stored contiguously without any gap in the memory. + /// + /// @param[in] height Height of the full matrix. + /// @param[in] width Width of the full matrix. + /// + /// @return The size in bytes of the matrix. 
+ [[nodiscard]] size_t default_size_in_bytes(size_t height, size_t width) const; + +private: + DataType _data_type; + QuantizationFormat _quant_format; + DataType _scale_dt; + DataType _offset_dt; + size_t _block_height; + size_t _block_width; +}; + +} // namespace kai::test diff --git a/test/common/data_type.cpp b/test/common/data_type.cpp new file mode 100644 index 00000000..2e7f7f38 --- /dev/null +++ b/test/common/data_type.cpp @@ -0,0 +1,81 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/data_type.hpp" + +#include +#include + +#include "src/kai_common.h" + +namespace kai::test { + +namespace { + +bool has_i(DataType dt) { + return static_cast(dt) & (1 << 15); +} + +bool has_s(DataType dt) { + return static_cast(dt) & (1 << 14); +} + +bool has_q(DataType dt) { + return static_cast(dt) & (1 << 13); +} + +bool has_a(DataType dt) { + return static_cast(dt) & (1 << 12); +} + +size_t bits(DataType dt) { + return static_cast(dt) & 0xFF; +} + +} // namespace + +size_t data_type_size_in_bits(DataType dt) { + return bits(dt); +} + +bool data_type_is_integral(DataType dt) { + return has_i(dt); +} + +bool data_type_is_float(DataType dt) { + KAI_ASSERT(data_type_is_signed(dt)); + return !data_type_is_integral(dt); +} + +bool data_type_is_float_fp(DataType dt) { + KAI_ASSERT(data_type_is_float(dt)); + return !has_q(dt); +} + +bool data_type_is_float_bf(DataType dt) { + KAI_ASSERT(data_type_is_float(dt)); + return has_q(dt); +} + +bool data_type_is_signed(DataType dt) { + if (!has_s(dt)) { + KAI_ASSERT(data_type_is_integral(dt)); + } + + return has_s(dt); +} + +bool data_type_is_quantized(DataType dt) { + KAI_ASSERT(data_type_is_integral(dt)); + return has_q(dt); +} + +bool data_type_is_quantized_asymm(DataType dt) { + KAI_ASSERT(data_type_is_quantized(dt)); + return has_a(dt); +} + +} // namespace kai::test diff --git a/test/common/data_type.hpp 
b/test/common/data_type.hpp new file mode 100644 index 00000000..17aeedab --- /dev/null +++ b/test/common/data_type.hpp @@ -0,0 +1,107 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace kai::test { + +/// Data type. +enum class DataType : uint16_t { + // Encoding: + // + // 15 0 + // +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ + // | i | s | q | a | RES0 | bits | + // +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ + // + // (RES0: reserved, filled with 0s) + // + // Fields: + // + // * i: integer (1) or floating-point (0). + // * s: signed (1) or unsigned (0). + // * q: + // - Integer (i): quantized (1) or non-quantized (0). + // - Floating-point (!i): brain (1) or binary (0). + // * a: + // - Quantized (i && q): asymmetric (1) or symmetric (0). + // - Otherwise: RES0. + // * bits: size in bits. + + UNKNOWN = 0, ///< No data. + + FP32 = 0b0'1'0'0'0000'00100000, ///< Single-precision floating-point. + FP16 = 0b0'1'0'0'0000'00010000, ///< Half-precision floating-point. + + I32 = 0b1'1'0'0'0000'00100000, ///< 32-bit signed integer. + + QI8 = 0b1'1'1'1'0000'00001000, ///< 8-bit asymmetric quantized. +}; + +/// Gets the size in bits of the specified data type. +/// +/// @param[in] dt The data type. +/// +/// @return The size in bits. +[[nodiscard]] size_t data_type_size_in_bits(DataType dt); + +/// Gets a value indicating whether the data type is integral. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is integral. +[[nodiscard]] bool data_type_is_integral(DataType dt); + +/// Gets a value indicating whether the data type is floating-point. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is floating-point. 
+[[nodiscard]] bool data_type_is_float(DataType dt); + +/// Gets a value indicating whether the data type is binary floating-point. +/// +/// Binary floating point are `half`, `float`, `double`. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is binary floating-point. +[[nodiscard]] bool data_type_is_float_fp(DataType dt); + +/// Gets a value indicating whether the data type is brain floating-point. +/// +/// Brain floating point are `bfloat16`. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is brain floating-point. +[[nodiscard]] bool data_type_is_float_bf(DataType dt); + +/// Gets a value indicating whether the data type is signed. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is signed. +[[nodiscard]] bool data_type_is_signed(DataType dt); + +/// Gets a value indicating whether the data type is quantized. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is quantized. +[[nodiscard]] bool data_type_is_quantized(DataType dt); + +/// Gets a value indicating whether the data type is asymmetric quantized. +/// +/// @param[in] dt The data type. +/// +/// @return `true` if the data type is asymmetric quantized. +[[nodiscard]] bool data_type_is_quantized_asymm(DataType dt); + +} // namespace kai::test diff --git a/test/common/logging.hpp b/test/common/logging.hpp new file mode 100644 index 00000000..85902bc6 --- /dev/null +++ b/test/common/logging.hpp @@ -0,0 +1,33 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#define KAI_LOGE(...) kai::test::detail::log("DEBUG", __VA_ARGS__) + +namespace kai::test::detail { + +template +void write_log_content(T&& value) { + std::cerr << value; +} + +template +void write_log_content(T&& value, Ts&&... 
others) { + write_log_content(std::forward(value)); + write_log_content(std::forward(others)...); +} + +template +void log(std::string_view level, Ts&&... args) { + std::cerr << level << " | "; + write_log_content(std::forward(args)...); +} + +} // namespace kai::test::detail diff --git a/test/common/matrix_portion.cpp b/test/common/matrix_portion.cpp new file mode 100644 index 00000000..d886af34 --- /dev/null +++ b/test/common/matrix_portion.cpp @@ -0,0 +1,48 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/matrix_portion.hpp" + +#include +#include +#include + +#include "test/reference/round.hpp" + +namespace kai::test { + +MatrixPortion::MatrixPortion(float start_row, float start_col, float height, float width) : + _start_row(start_row), _start_col(start_col), _height(height), _width(width) { +} + +std::tuple MatrixPortion::compute_portion( + size_t full_height, size_t full_width, size_t block_height, size_t block_width) const { + const auto start_row_f = std::clamp(_start_row, 0, 1); + const auto start_col_f = std::clamp(_start_col, 0, 1); + const auto height_f = std::clamp(_height, 0, 1); + const auto width_f = std::clamp(_width, 0, 1); + + auto start_row = round_to_nearest_even_usize(start_row_f * full_height); + auto start_col = round_to_nearest_even_usize(start_col_f * full_width); + auto height = round_to_nearest_even_usize(height_f * full_height); + auto width = round_to_nearest_even_usize(width_f * full_width); + + start_row = round_down_multiple(start_row, block_height); + start_col = round_down_multiple(start_col, block_width); + + start_row = std::min(start_row, round_down_multiple(full_height, block_height)); + start_col = std::min(start_col, round_down_multiple(full_width, block_width)); + + height = round_up_multiple(height, block_height); + width = round_up_multiple(width, block_width); + + height = std::min(height, full_height - start_row); 
+ width = std::min(width, full_width - start_col); + + return {start_row, start_col, height, width}; +} + +} // namespace kai::test diff --git a/test/common/matrix_portion.hpp b/test/common/matrix_portion.hpp new file mode 100644 index 00000000..e0a37dcd --- /dev/null +++ b/test/common/matrix_portion.hpp @@ -0,0 +1,53 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace kai::test { + +/// Portion of a matrix. +/// +/// This class is used to define the sub-matrix a test is running and checking. +class MatrixPortion { +public: + /// Creates a new matrix portion. + /// + /// @param[in] start_row Starting row as the ratio to the height of the matrix. + /// @param[in] start_col Starting column as the ratio to the width of the matrix. + /// @param[in] height Portion height as the ratio to the height of the matrix. + /// @param[in] width Portion width as the ratio to the width of the matrix. + MatrixPortion(float start_row, float start_col, float height, float width); + + /// Computes the starting coordinate and the shape of the sub-matrix. + /// + /// Requirements: + /// + /// * The starting coordinate of the sub-matrix shall be aligned with the block boundary. + /// * If it is not the block at the right and/or bottom edge of the full matrix, the height and width + /// of the sub-matrix shall be rounded up to a multiple of the block height and width. + /// * If it is the block at the right and/or bottom edge of the full matrix, the height and width + /// of the sub-matrix shall be rounded up to the edge of the matrix. + /// + /// @param[in] full_height Matrix height. + /// @param[in] full_width Matrix width. + /// @param[in] block_height Block height. + /// @param[in] block_width Block width. + /// + /// @return The starting row, starting column, height and width of the sub-matrix. 
+ [[nodiscard]] std::tuple compute_portion( + size_t full_height, size_t full_width, size_t block_height, size_t block_width) const; + +private: + float _start_row; + float _start_col; + float _height; + float _width; +}; + +} // namespace kai::test diff --git a/test/common/printer.cpp b/test/common/printer.cpp new file mode 100644 index 00000000..25386553 --- /dev/null +++ b/test/common/printer.cpp @@ -0,0 +1,99 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" + +namespace kai::test { + +namespace { + +inline void print_data(std::ostream& os, const uint8_t* data, size_t len, DataType data_type) { + for (size_t i = 0; i < len; ++i) { + switch (data_type) { + case DataType::FP32: + os << reinterpret_cast(data)[i]; + break; + + case DataType::I32: + os << reinterpret_cast(data)[i]; + break; + + case DataType::QI8: + os << static_cast(reinterpret_cast(data)[i]); + break; + + default: + KAI_ERROR("Unsupported data type!"); + } + + os << ", "; + } +} + +void print_matrix_raw(std::ostream& os, const uint8_t* data, DataType data_type, size_t height, size_t width) { + const auto row_stride = width * data_type_size_in_bits(data_type) / 8; + + os << "[\n"; + for (size_t y = 0; y < height; ++y) { + os << " ["; + print_data(os, data + y * row_stride, width, data_type); + os << "]\n"; + } + os << "]\n"; +} + +void print_matrix_per_row( + std::ostream& os, const uint8_t* data, const DataFormat& format, size_t height, size_t width) { + const auto block_height = format.block_height(); + const auto num_blocks = (height + block_height - 1) / block_height; + + const auto block_data_bytes = block_height * width * data_type_size_in_bits(format.data_type()) / 8; + const auto block_offsets_bytes = block_height * 
data_type_size_in_bits(format.offset_data_type()) / 8; + const auto block_scales_bytes = block_height * data_type_size_in_bits(format.scale_data_type()) / 8; + + os << "[\n"; + for (size_t y = 0; y < num_blocks; ++y) { + os << " {\"offsets\": ["; + print_data(os, data, block_height, format.offset_data_type()); + os << "], \"data\": ["; + print_data(os, data + block_offsets_bytes, block_height * width, format.data_type()); + os << "], \"scales\": ["; + print_data(os, data + block_offsets_bytes + block_data_bytes, block_height, format.scale_data_type()); + os << "]},\n"; + + data += block_offsets_bytes + block_data_bytes + block_scales_bytes; + } + os << "]\n"; +} + +} // namespace + +void print_matrix( + std::ostream& os, std::string_view name, const void* data, const DataFormat& format, size_t height, size_t width) { + os << name << " = "; + + switch (format.quantization_format()) { + case DataFormat::QuantizationFormat::NONE: + print_matrix_raw(os, reinterpret_cast(data), format.data_type(), height, width); + break; + + case DataFormat::QuantizationFormat::PER_ROW: + print_matrix_per_row(os, reinterpret_cast(data), format, height, width); + break; + + default: + KAI_ERROR("Unsupported quantization packing format!"); + } +} + +} // namespace kai::test diff --git a/test/common/printer.hpp b/test/common/printer.hpp new file mode 100644 index 00000000..ec6213ae --- /dev/null +++ b/test/common/printer.hpp @@ -0,0 +1,28 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace kai::test { + +class DataFormat; + +/// Prints the matrix data to the output stream. +/// +/// @param[in] os Output stream to write the data to. +/// @param[in] name Matrix name. +/// @param[in] data Data buffer. +/// @param[in] format Data format. +/// @param[in] height Number of rows. +/// @param[in] width Number of columns. 
+void print_matrix( + std::ostream& os, std::string_view name, const void* data, const DataFormat& format, size_t height, size_t width); + +} // namespace kai::test diff --git a/test/reference/fill.cpp b/test/reference/fill.cpp new file mode 100644 index 00000000..4d7e3267 --- /dev/null +++ b/test/reference/fill.cpp @@ -0,0 +1,66 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/fill.hpp" + +#include +#include +#include +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_format.hpp" + +namespace kai::test { + +template +std::vector fill_matrix_raw(size_t height, size_t width, std::function gen) { + const auto size = height * width * sizeof(T); + std::vector data; + data.resize(size); + auto ptr = reinterpret_cast(data.data()); + + for (size_t y = 0; y < height; ++y) { + for (size_t x = 0; x < width; ++x) { + ptr[y * width + x] = gen(y, x); + } + } + + return data; +} + +template +std::vector fill_matrix_random_raw(size_t height, size_t width, [[maybe_unused]] uint64_t seed) { + using TDist = std::conditional_t< + std::is_floating_point_v, std::uniform_real_distribution, std::uniform_int_distribution>; + + std::mt19937 rnd(seed); + TDist dist; + + return fill_matrix_raw(height, width, [&](size_t, size_t) { return dist(rnd); }); +} + +std::vector fill_matrix_random(size_t height, size_t width, const DataFormat& format, uint64_t seed) { + switch (format.quantization_format()) { + case DataFormat::QuantizationFormat::NONE: + switch (format.data_type()) { + case DataType::FP32: + return fill_matrix_random_raw(height, width, seed); + + default: + KAI_ERROR("Unsupported data type!"); + } + + break; + + default: + KAI_ERROR("Unsupported data format!"); + } +} + +} // namespace kai::test diff --git a/test/reference/fill.hpp b/test/reference/fill.hpp new file mode 100644 index 00000000..09138145 --- /dev/null +++ 
b/test/reference/fill.hpp @@ -0,0 +1,27 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace kai::test { + +class DataFormat; + +/// Creates a new matrix filled with random data. +/// +/// @param[in] height Number of rows. +/// @param[in] width Number of columns. +/// @param[in] format Data format. +/// @param[in] seed Random seed. +/// +/// @return The data buffer for the matrix. +std::vector fill_matrix_random(size_t height, size_t width, const DataFormat& format, uint64_t seed); + +} // namespace kai::test diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp new file mode 100644 index 00000000..27b3cae6 --- /dev/null +++ b/test/reference/pack.cpp @@ -0,0 +1,114 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/pack.hpp" + +#include +#include +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/reference/quantize.hpp" + +namespace kai::test { + +namespace { + +/// Packs the matrix from raw to quantized format. 
+template +std::vector pack_quant_per_row( + const void* src, size_t height, size_t width, size_t block_height, size_t block_width) { + const auto num_groups = (height + block_height - 1) / block_height; + const auto group_num_blocks = (width + block_width - 1) / block_width; + + const auto group_offsets_bytes = block_height * sizeof(Offset); + const auto group_scales_bytes = block_height * sizeof(Scale); + const auto block_data_bytes = block_height * block_width * sizeof(Output); + const auto group_bytes = group_offsets_bytes + group_num_blocks * block_data_bytes + group_scales_bytes; + const auto dst_bytes = num_groups * group_bytes; + + std::vector dst; + dst.resize(dst_bytes); + + const auto* src_ptr = reinterpret_cast(src); + auto* dst_ptr = dst.data(); + + for (size_t group_no = 0; group_no < num_groups; ++group_no) { + Scale scales[block_height]; + Offset offsets[block_height]; + + // Finds the range of values and calculates the quantization information. + for (size_t y = 0; y < block_height; ++y) { + auto min_value = std::numeric_limits::max(); + auto max_value = std::numeric_limits::lowest(); + + for (size_t x = 0; x < width; ++x) { + const auto value = src_ptr[(group_no * block_height + y) * width + x]; + + if (value < min_value) { + min_value = value; + } + + if (value > max_value) { + max_value = value; + } + } + + std::tie(scales[y], offsets[y]) = get_qi8_scale_offset_from_range(min_value, max_value); + } + + // Packs the offsets. + memcpy(dst_ptr, offsets, group_offsets_bytes); + dst_ptr += group_offsets_bytes; + + // Quantizes and packs the data. 
+ for (size_t x_block = 0; x_block < group_num_blocks; ++x_block) { + for (size_t block_y = 0; block_y < block_height; ++block_y) { + for (size_t block_x = 0; block_x < block_width; ++block_x) { + const auto value = + src_ptr[(group_no * block_height + block_y) * width + x_block * block_width + block_x]; + const auto qvalue = quantize_i8_fp32(value, scales[block_y], offsets[block_y]); + *reinterpret_cast(dst_ptr) = qvalue; + ++dst_ptr; + } + } + } + + // Packs the scales. + memcpy(dst_ptr, scales, group_scales_bytes); + dst_ptr += group_scales_bytes; + } + + KAI_ASSERT(reinterpret_cast(dst_ptr) - reinterpret_cast(dst.data()) == dst_bytes); + + return dst; +} + +} // namespace + +std::vector pack( + const DataFormat& dst_format, const void* src, const DataFormat& src_format, size_t height, size_t width) { + const auto dst_dt = dst_format.data_type(); + const auto dst_qf = dst_format.quantization_format(); + const auto src_dt = src_format.data_type(); + const auto src_qf = src_format.quantization_format(); + + if (src_qf == DataFormat::QuantizationFormat::NONE && dst_qf == DataFormat::QuantizationFormat::PER_ROW) { + if (dst_dt == DataType::QI8 && src_dt == DataType::FP32 && dst_format.scale_data_type() == DataType::FP32 && + dst_format.offset_data_type() == DataType::I32) { + return pack_quant_per_row( + src, height, width, dst_format.block_height(), dst_format.block_width()); + } + } + + KAI_ERROR("Unsupported operation!"); +} + +} // namespace kai::test diff --git a/test/reference/pack.hpp b/test/reference/pack.hpp new file mode 100644 index 00000000..cfd2540d --- /dev/null +++ b/test/reference/pack.hpp @@ -0,0 +1,27 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace kai::test { + +class DataFormat; + +/// Packs the matrix. +/// +/// @param[in] dst_format Data format of the destination matrix. 
+/// @param[in] src Data buffer of the source matrix. +/// @param[in] src_format Data format of the source matrix. +/// @param[in] height Number of rows of the source matrix. +/// @param[in] width Number of columns of the source matrix. +std::vector pack( + const DataFormat& dst_format, const void* src, const DataFormat& src_format, size_t height, size_t width); + +} // namespace kai::test diff --git a/test/reference/quantize.cpp b/test/reference/quantize.cpp new file mode 100644 index 00000000..67c7037e --- /dev/null +++ b/test/reference/quantize.cpp @@ -0,0 +1,48 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include + +#include "test/reference/round.hpp" + +namespace kai::test { + +std::tuple get_qi8_scale_offset_from_range(float min_value, float max_value) { + constexpr float q_min = std::numeric_limits::min(); + constexpr float q_max = std::numeric_limits::max(); + + if (min_value > 0) { + min_value = 0; + } + + if (max_value < 0) { + max_value = 0; + } + + // The reason for computing the inverted scale first is to make it bit-perfect with quantized packing kernels. + // If those kernels don't do it this way anymore, it makes more sense to calculate the scale directly. + const float inv_scale = max_value != min_value ? (q_max - q_min) / (max_value - min_value) : 1.0f; + const float scale = 1.0f / inv_scale; + + const float scaled_min = min_value / scale; + const float scaled_max = max_value / scale; + + const float offset_f = -(scaled_min + q_min) < scaled_max + q_max ? 
scaled_min - q_min : scaled_max - q_max; + const int32_t offset = round_to_nearest_even_i32(offset_f); + + return {scale, offset}; +} + +int8_t quantize_i8_fp32(float value, float scale, int32_t offset) { + return std::clamp( + round_to_nearest_even_i32(value / scale) - offset, std::numeric_limits::min(), + std::numeric_limits::max()); +} + +} // namespace kai::test diff --git a/test/reference/quantize.hpp b/test/reference/quantize.hpp new file mode 100644 index 00000000..46c4bf8f --- /dev/null +++ b/test/reference/quantize.hpp @@ -0,0 +1,33 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace kai::test { + +/// Calculates the quantization information for 8-bit signed asymmetric type from the value range. +/// +/// @param[in] min_value Minimum value. +/// @param[in] max_value Maximum value. +/// +/// @return The scale and offset. +std::tuple get_qi8_scale_offset_from_range(float min_value, float max_value); + +/// Quantizes the single-precision floating-point value using 8-bit asymmetric quantization. +/// +/// Formula: `q = f / scale + offset` where `q` is quantized value and `f` is floating-point value. +/// +/// @param[in] value Value to be quantized. +/// @param[in] scale Scale. +/// @param[in] offset Offset. +/// +/// @return The quantized value. 
+int8_t quantize_i8_fp32(float value, float scale, int32_t offset); + +} // namespace kai::test diff --git a/test/reference/round.cpp b/test/reference/round.cpp new file mode 100644 index 00000000..c6fb493d --- /dev/null +++ b/test/reference/round.cpp @@ -0,0 +1,36 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/round.hpp" + +#include +#include + +namespace kai::test { + +int32_t round_to_nearest_even_i32(float value) { + int32_t rounded = 0; + asm("fcvtns %w[output], %s[input]" : [output] "=r"(rounded) : [input] "w"(value)); // "=r": Wd is a GPR + return rounded; +} + +size_t round_to_nearest_even_usize(float value) { + static_assert(sizeof(size_t) == sizeof(uint64_t)); + + uint64_t rounded = 0; + asm("fcvtns %x[output], %s[input]" : [output] "=r"(rounded) : [input] "w"(value)); + return rounded; +} + +size_t round_up_multiple(size_t a, size_t b) { + return ((a + b - 1) / b) * b; +} + +size_t round_down_multiple(size_t a, size_t b) { + return (a / b) * b; +} + +} // namespace kai::test diff --git a/test/reference/round.hpp b/test/reference/round.hpp new file mode 100644 index 00000000..7503936f --- /dev/null +++ b/test/reference/round.hpp @@ -0,0 +1,62 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace kai::test { + +/// Rounds the specified value to nearest with tie to even. +/// +/// For example: +/// +/// * 0.4 is rounded to 0. +/// * 0.5 is rounded to 0 (as 0 is the nearest even value). +/// * 0.6 is rounded to 1. +/// * 1.4 is rounded to 1. +/// * 1.5 is rounded to 2 (as 2 is the nearest even value). +/// * 1.6 is rounded to 2. +/// +/// @param[in] value Value to be rounded. +/// +/// @return The rounded value. +int32_t round_to_nearest_even_i32(float value); + +/// Rounds the specified value to nearest with tie to even. 
+/// +/// For example: +/// +/// * 0.4 is rounded to 0. +/// * 0.5 is rounded to 0 (as 0 is the nearest even value). +/// * 0.6 is rounded to 1. +/// * 1.4 is rounded to 1. +/// * 1.5 is rounded to 2 (as 2 is the nearest even value). +/// * 1.6 is rounded to 2. +/// +/// @param[in] value Value to be rounded. +/// +/// @return The rounded value. +size_t round_to_nearest_even_usize(float value); + +/// Rounds up the input value to the multiple of the unit value. +/// +/// @param[in] a Input value. +/// @param[in] b Unit value. +/// +/// @return The rounded value. +size_t round_up_multiple(size_t a, size_t b); + +/// Rounds down the input value to the multiple of the unit value. +/// +/// @param[in] a Input value. +/// @param[in] b Unit value. +/// +/// @return The rounded value. +size_t round_down_multiple(size_t a, size_t b); + +} // namespace kai::test diff --git a/test/sample.cpp b/test/sample.cpp deleted file mode 100644 index d8e63383..00000000 --- a/test/sample.cpp +++ /dev/null @@ -1,10 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// -#include - -TEST(SampleSuite, SampleTest) { - EXPECT_EQ(1 + 2, 3); -} diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp new file mode 100644 index 00000000..0091ef77 --- /dev/null +++ b/test/tests/matmul_test.cpp @@ -0,0 +1,190 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.h" +#include "test/common/compare.hpp" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/common/matrix_portion.hpp" +#include "test/common/printer.hpp" +#include "test/reference/fill.hpp" +#include "test/reference/pack.hpp" + +namespace kai::test { + +/// Matrix multiplication method. 
+struct MatMulMethod { + size_t m0; ///< Block size in M dimension. + size_t n0; ///< Block size in N dimension. + + DataFormat dst_format; ///< Data format of the destination matrix. + DataFormat lhs_format; ///< Data format of the LHS matrix. + DataFormat p_lhs_format; ///< Data format of the preprocessed LHS matrix. + DataFormat rhs_format; ///< Data format of the RHS matrix. + + /// Gets the offset in bytes of the LHS matrix. + /// + /// @param[in] m_idx Coordinate of the matrix in M dimension. + /// @param[in] stride Row stride in bytes. + std::function fn_get_lhs_offset; + + /// Gets the size in bytes of the packed LHS matrix. + /// + /// @param[in] m Size of the matrix in M dimension. + /// @param[in] k Size of the matrix in K dimension. + std::function fn_get_preprocessed_lhs_size; + + /// Gets the offset in bytes of the packed LHS matrix. + /// + /// @param[in] m_idx Coordinate of the matrix in M dimension. + /// @param[in] k Size of the matrix in K dimension. + std::function fn_get_preprocessed_lhs_offset; + + /// Preprocesses the LHS matrix. + /// + /// @param[in] m Size of the matrix in M dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] lhs LHS matrix data buffer. + /// @param[in] lhs_row_stride Row stride in bytes of the LHS matrix. + /// @param[in] preprocessed_lhs Preprocessed LHS matrix data buffer. + std::function + fn_preprocess_lhs; +}; + +/// List of supported matrix multiplication methods. +static const std::array matmul_methods = { + MatMulMethod{ + .m0 = 4, + .n0 = 4, + + .dst_format = DataFormat(DataType::FP32), + .lhs_format = DataFormat(DataType::FP32), + .p_lhs_format = + DataFormat(DataType::QI8, 4, 8, DataFormat::QuantizationFormat::PER_ROW, DataType::FP32, DataType::I32), + .rhs_format = DataFormat(DataType::UNKNOWN), // Unused. 
+ + .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qa8dxP4X8_f32, + .fn_get_preprocessed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP4X8_f32, + .fn_get_preprocessed_lhs_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qa8dxP4X8_f32, + .fn_preprocess_lhs = kai_run_lhs_quant_pack_qa8dxP4X8_f32, + }, +}; + +/// Matrix multiplication shape. +struct MatMulShape { + size_t m; ///< LHS height. + size_t n; ///< RHS width. + size_t k; ///< LHS width and RHS height. +}; + +/// Matrix multiplication test information. +using MatMulTestParams = std::tuple; + +/// Matrix multiplication test fixture. +class MatMulTest : public testing::TestWithParam { +private: + /// Unique ID: m, n, k, method_id. + using TestDataId = std::tuple; + +protected: + /// Cached test data that is shared between multiple test case. + struct TestData { + std::vector lhs; ///< LHS operand. + std::vector ref_p_lhs; ///< Reference packed LHS. + }; + + /// Gets the test data for the current test case. + TestData& test_data() { + const auto& [info, method_no, portion] = GetParam(); + const TestDataId data_id{info.m, info.n, info.k, method_no}; + + // If the test data is already available, returns it. + const auto data_it = _data.find(data_id); + + if (data_it != _data.end()) { + return data_it->second; + } + + // Generates the test data. + const auto& method = matmul_methods.at(method_no); + + auto lhs = fill_matrix_random(info.m, info.k, method.lhs_format, 0); + auto ref_p_lhs = pack(method.p_lhs_format, lhs.data(), method.lhs_format, info.m, info.k); + + return _data[data_id] = {std::move(lhs), std::move(ref_p_lhs)}; + } + +private: + std::map _data; +}; + +/// Tests the LHS packing kernel. 
+TEST_P(MatMulTest, PackedLhs) { + const auto& [info, method_no, portion] = GetParam(); + const auto& data = test_data(); + const auto& method = matmul_methods.at(method_no); + + if (method.fn_preprocess_lhs == nullptr) { + GTEST_SKIP(); + } + + const auto [rect_start_row, rect_start_col, rect_height, rect_width] = portion.compute_portion( + info.m, info.k, method.p_lhs_format.scheduler_block_height(info.m), + method.p_lhs_format.scheduler_block_width(info.k)); + + if (rect_height == 0 || rect_width == 0) { + GTEST_SKIP(); + } + + const auto ref_lhs_row_stride = method.lhs_format.default_row_stride(info.k); + + const auto p_lhs_size = method.fn_get_preprocessed_lhs_size(info.m, info.k); + const auto ref_p_lhs_size = method.p_lhs_format.default_size_in_bytes(info.m, info.k); + ASSERT_EQ(p_lhs_size, ref_p_lhs_size); + + const auto lhs_offset = method.fn_get_lhs_offset(rect_start_row, ref_lhs_row_stride); + const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(rect_start_row, 0, info.k); + ASSERT_EQ(lhs_offset, ref_lhs_offset); + + const auto p_lhs_offset = method.fn_get_preprocessed_lhs_offset(rect_start_row, info.k); + const auto ref_p_lhs_offset = method.p_lhs_format.default_offset_in_bytes(rect_start_row, 0, info.k); + ASSERT_EQ(p_lhs_offset, ref_p_lhs_offset); + + std::vector p_lhs; + p_lhs.resize(p_lhs_size); + method.fn_preprocess_lhs( + rect_height, rect_width, data.lhs.data() + lhs_offset, ref_lhs_row_stride, p_lhs.data() + p_lhs_offset); + + const auto success = compare( + p_lhs.data(), data.ref_p_lhs.data(), method.p_lhs_format, info.m, info.k, portion, + DefaultMismatchHandler(0, 0.0001, 0, 0.001)); + ASSERT_TRUE(success); +} + +INSTANTIATE_TEST_SUITE_P( + MatMul, MatMulTest, + testing::Combine( + testing::Values( + MatMulShape{4, 4, 32}, // + MatMulShape{12, 16, 48}), + testing::Range(0, matmul_methods.size()), + testing::Values( + MatrixPortion(0, 0, 1, 1), // Full matrix. + MatrixPortion(0, 0, 0.25, 0.25), // Top-left corner. 
+ MatrixPortion(0.75, 0.75, 1, 1) // Bottom-right corner. + ))); + +} // namespace kai::test -- GitLab From 9b52afd59311f604659c8153e241b775bc68531c Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Wed, 17 Apr 2024 10:59:59 +0100 Subject: [PATCH 2/3] Fix typos and improve variable names Signed-off-by: Viet-Hoa Do --- test/common/logging.hpp | 2 +- test/common/matrix_portion.cpp | 14 +++++++------- test/common/matrix_portion.hpp | 6 +++--- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/test/common/logging.hpp b/test/common/logging.hpp index 85902bc6..ae1f11c0 100644 --- a/test/common/logging.hpp +++ b/test/common/logging.hpp @@ -9,7 +9,7 @@ #include #include -#define KAI_LOGE(...) kai::test::detail::log("DEBUG", __VA_ARGS__) +#define KAI_LOGE(...) kai::test::detail::log("ERROR", __VA_ARGS__) namespace kai::test::detail { diff --git a/test/common/matrix_portion.cpp b/test/common/matrix_portion.cpp index d886af34..a6844ed3 100644 --- a/test/common/matrix_portion.cpp +++ b/test/common/matrix_portion.cpp @@ -19,7 +19,7 @@ MatrixPortion::MatrixPortion(float start_row, float start_col, float height, flo } std::tuple MatrixPortion::compute_portion( - size_t full_height, size_t full_width, size_t block_height, size_t block_width) const { + size_t full_height, size_t full_width, size_t scheduler_block_height, size_t scheduler_block_width) const { const auto start_row_f = std::clamp(_start_row, 0, 1); const auto start_col_f = std::clamp(_start_col, 0, 1); const auto height_f = std::clamp(_height, 0, 1); @@ -30,14 +30,14 @@ std::tuple MatrixPortion::compute_portion( auto height = round_to_nearest_even_usize(height_f * full_height); auto width = round_to_nearest_even_usize(width_f * full_width); - start_row = round_down_multiple(start_row, block_height); - start_col = round_down_multiple(start_col, block_width); + start_row = round_down_multiple(start_row, scheduler_block_height); + start_col = round_down_multiple(start_col, scheduler_block_width); - start_row 
= std::min(start_row, round_down_multiple(full_height, block_height)); - start_col = std::min(start_col, round_down_multiple(full_width, block_width)); + start_row = std::min(start_row, round_down_multiple(full_height, scheduler_block_height)); + start_col = std::min(start_col, round_down_multiple(full_width, scheduler_block_width)); - height = round_up_multiple(height, block_height); - width = round_up_multiple(width, block_width); + height = round_up_multiple(height, scheduler_block_height); + width = round_up_multiple(width, scheduler_block_width); height = std::min(height, full_height - start_row); width = std::min(width, full_width - start_col); diff --git a/test/common/matrix_portion.hpp b/test/common/matrix_portion.hpp index e0a37dcd..3b38220a 100644 --- a/test/common/matrix_portion.hpp +++ b/test/common/matrix_portion.hpp @@ -36,12 +36,12 @@ public: /// /// @param[in] full_height Matrix height. /// @param[in] full_width Matrix width. - /// @param[in] block_height Block height. - /// @param[in] block_width Block width. + /// @param[in] scheduler_block_height Block height for scheduling purpose. + /// @param[in] scheduler_block_width Block width for scheduling purpose. /// /// @return The starting row, starting column, height and width of the sub-matrix. [[nodiscard]] std::tuple compute_portion( - size_t full_height, size_t full_width, size_t block_height, size_t block_width) const; + size_t full_height, size_t full_width, size_t scheduler_block_height, size_t scheduler_block_width) const; private: float _start_row; -- GitLab From 06496e3ab7d08ee21db654d2b853466362922d33 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Wed, 24 Apr 2024 14:05:58 +0100 Subject: [PATCH 3/3] Improve variable names * Use packed for pre-processed matmul operands. * Use zero point instead of offset in quantization information. 
Signed-off-by: Viet-Hoa Do --- test/common/compare.cpp | 2 +- test/common/data_format.cpp | 12 ++++---- test/common/data_format.hpp | 10 +++---- test/common/printer.cpp | 4 +-- test/reference/pack.cpp | 20 ++++++------- test/reference/quantize.cpp | 12 ++++---- test/reference/quantize.hpp | 10 +++---- test/tests/matmul_test.cpp | 56 ++++++++++++++++++------------------- 8 files changed, 63 insertions(+), 63 deletions(-) diff --git a/test/common/compare.cpp b/test/common/compare.cpp index e9054b51..b81ca52b 100644 --- a/test/common/compare.cpp +++ b/test/common/compare.cpp @@ -143,7 +143,7 @@ bool compare( const MatrixPortion& portion, MismatchHandler&& handler) { const auto data_type = format.data_type(); const auto scale_dt = format.scale_data_type(); - const auto offset_dt = format.offset_data_type(); + const auto offset_dt = format.zero_point_data_type(); switch (format.quantization_format()) { case DataFormat::QuantizationFormat::PER_ROW: diff --git a/test/common/data_format.cpp b/test/common/data_format.cpp index 35fddaaa..9310ad93 100644 --- a/test/common/data_format.cpp +++ b/test/common/data_format.cpp @@ -16,18 +16,18 @@ namespace kai::test { DataFormat::DataFormat( DataType data_type, size_t block_height, size_t block_width, QuantizationFormat quant_format, DataType scale_dt, - DataType offset_dt) noexcept : + DataType zero_point_dt) noexcept : _data_type(data_type), _quant_format(quant_format), _scale_dt(scale_dt), - _offset_dt(offset_dt), + _zero_point_dt(zero_point_dt), _block_height(block_height), _block_width(block_width) { } bool DataFormat::operator==(const DataFormat& rhs) const { return _data_type == rhs._data_type && _quant_format == rhs._quant_format && _scale_dt == rhs._scale_dt && - _offset_dt == rhs._offset_dt && _block_height == rhs._block_height && _block_width == rhs._block_width; + _zero_point_dt == rhs._zero_point_dt && _block_height == rhs._block_height && _block_width == rhs._block_width; } bool DataFormat::operator!=(const 
DataFormat& rhs) const { @@ -46,8 +46,8 @@ DataType DataFormat::scale_data_type() const { return _scale_dt; } -DataType DataFormat::offset_data_type() const { - return _offset_dt; +DataType DataFormat::zero_point_data_type() const { + return _zero_point_dt; } size_t DataFormat::block_height() const { @@ -90,7 +90,7 @@ uintptr_t DataFormat::default_row_stride(size_t width) const { return padded_width * data_type_size_in_bits(_data_type) / 8; case QuantizationFormat::PER_ROW: - return _block_height * data_type_size_in_bits(_offset_dt) / 8 + // + return _block_height * data_type_size_in_bits(_zero_point_dt) / 8 + // _block_height * padded_width * data_type_size_in_bits(_data_type) / 8 + // _block_height * data_type_size_in_bits(_scale_dt) / 8; diff --git a/test/common/data_format.hpp b/test/common/data_format.hpp index bd9cfb7e..24a17574 100644 --- a/test/common/data_format.hpp +++ b/test/common/data_format.hpp @@ -26,13 +26,13 @@ public: /// @param[in] data_type Data type of data value. /// @param[in] quant_format Quantization packing format. /// @param[in] scale_dt Data type of scale value. - /// @param[in] offset_dt Data type of offset value. + /// @param[in] zero_point_dt Data type of zero point value. /// @param[in] block_height Block height. /// @param[in] block_width Block width. DataFormat( DataType data_type, size_t block_height = 1, size_t block_width = 1, QuantizationFormat quant_format = QuantizationFormat::NONE, DataType scale_dt = DataType::UNKNOWN, - DataType offset_dt = DataType::UNKNOWN) noexcept; + DataType zero_point_dt = DataType::UNKNOWN) noexcept; /// Equality operator. [[nodiscard]] bool operator==(const DataFormat& rhs) const; @@ -49,8 +49,8 @@ public: /// Gets the data type of scale value. [[nodiscard]] DataType scale_data_type() const; - /// Gets the data type of offset value. - [[nodiscard]] DataType offset_data_type() const; + /// Gets the data type of zero point value. 
+ [[nodiscard]] DataType zero_point_data_type() const; /// Gets the block height. [[nodiscard]] size_t block_height() const; @@ -104,7 +104,7 @@ private: DataType _data_type; QuantizationFormat _quant_format; DataType _scale_dt; - DataType _offset_dt; + DataType _zero_point_dt; size_t _block_height; size_t _block_width; }; diff --git a/test/common/printer.cpp b/test/common/printer.cpp index 25386553..256ea970 100644 --- a/test/common/printer.cpp +++ b/test/common/printer.cpp @@ -58,13 +58,13 @@ void print_matrix_per_row( const auto num_blocks = (height + block_height - 1) / block_height; const auto block_data_bytes = block_height * width * data_type_size_in_bits(format.data_type()) / 8; - const auto block_offsets_bytes = block_height * data_type_size_in_bits(format.offset_data_type()) / 8; + const auto block_offsets_bytes = block_height * data_type_size_in_bits(format.zero_point_data_type()) / 8; const auto block_scales_bytes = block_height * data_type_size_in_bits(format.scale_data_type()) / 8; os << "[\n"; for (size_t y = 0; y < num_blocks; ++y) { os << " {\"offsets\": ["; - print_data(os, data, block_height, format.offset_data_type()); + print_data(os, data, block_height, format.zero_point_data_type()); os << "], \"data\": ["; print_data(os, data + block_offsets_bytes, block_height * width, format.data_type()); os << "], \"scales\": ["; diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp index 27b3cae6..81c7c664 100644 --- a/test/reference/pack.cpp +++ b/test/reference/pack.cpp @@ -22,16 +22,16 @@ namespace kai::test { namespace { /// Packs the matrix from raw to quantized format. 
-template +template std::vector pack_quant_per_row( const void* src, size_t height, size_t width, size_t block_height, size_t block_width) { const auto num_groups = (height + block_height - 1) / block_height; const auto group_num_blocks = (width + block_width - 1) / block_width; - const auto group_offsets_bytes = block_height * sizeof(Offset); + const auto group_zero_points_bytes = block_height * sizeof(ZeroPoint); const auto group_scales_bytes = block_height * sizeof(Scale); const auto block_data_bytes = block_height * block_width * sizeof(Output); - const auto group_bytes = group_offsets_bytes + group_num_blocks * block_data_bytes + group_scales_bytes; + const auto group_bytes = group_zero_points_bytes + group_num_blocks * block_data_bytes + group_scales_bytes; const auto dst_bytes = num_groups * group_bytes; std::vector dst; @@ -42,7 +42,7 @@ std::vector pack_quant_per_row( for (size_t group_no = 0; group_no < num_groups; ++group_no) { Scale scales[block_height]; - Offset offsets[block_height]; + ZeroPoint zero_points[block_height]; // Finds the range of values and calculates the quantization information. for (size_t y = 0; y < block_height; ++y) { @@ -61,12 +61,12 @@ std::vector pack_quant_per_row( } } - std::tie(scales[y], offsets[y]) = get_qi8_scale_offset_from_range(min_value, max_value); + std::tie(scales[y], zero_points[y]) = get_qi8_scale_zero_point_from_range(min_value, max_value); } - // Packs the offsets. - memcpy(dst_ptr, offsets, group_offsets_bytes); - dst_ptr += group_offsets_bytes; + // Packs the zero points. + memcpy(dst_ptr, zero_points, group_zero_points_bytes); + dst_ptr += group_zero_points_bytes; // Quantizes and packs the data. 
for (size_t x_block = 0; x_block < group_num_blocks; ++x_block) { @@ -74,7 +74,7 @@ std::vector pack_quant_per_row( for (size_t block_x = 0; block_x < block_width; ++block_x) { const auto value = src_ptr[(group_no * block_height + block_y) * width + x_block * block_width + block_x]; - const auto qvalue = quantize_i8_fp32(value, scales[block_y], offsets[block_y]); + const auto qvalue = quantize_i8_fp32(value, scales[block_y], zero_points[block_y]); *reinterpret_cast(dst_ptr) = qvalue; ++dst_ptr; } @@ -102,7 +102,7 @@ std::vector pack( if (src_qf == DataFormat::QuantizationFormat::NONE && dst_qf == DataFormat::QuantizationFormat::PER_ROW) { if (dst_dt == DataType::QI8 && src_dt == DataType::FP32 && dst_format.scale_data_type() == DataType::FP32 && - dst_format.offset_data_type() == DataType::I32) { + dst_format.zero_point_data_type() == DataType::I32) { return pack_quant_per_row( src, height, width, dst_format.block_height(), dst_format.block_width()); } diff --git a/test/reference/quantize.cpp b/test/reference/quantize.cpp index 67c7037e..98b29af5 100644 --- a/test/reference/quantize.cpp +++ b/test/reference/quantize.cpp @@ -13,7 +13,7 @@ namespace kai::test { -std::tuple get_qi8_scale_offset_from_range(float min_value, float max_value) { +std::tuple get_qi8_scale_zero_point_from_range(float min_value, float max_value) { constexpr float q_min = std::numeric_limits::min(); constexpr float q_max = std::numeric_limits::max(); @@ -33,15 +33,15 @@ std::tuple get_qi8_scale_offset_from_range(float min_value, floa const float scaled_min = min_value / scale; const float scaled_max = max_value / scale; - const float offset_f = -(scaled_min + q_min) < scaled_max + q_max ? scaled_min - q_min : scaled_max - q_max; - const int32_t offset = round_to_nearest_even_i32(offset_f); + const float zero_point_f = -(scaled_min + q_min) < scaled_max + q_max ? 
scaled_min - q_min : scaled_max - q_max; + const int32_t zero_point = round_to_nearest_even_i32(zero_point_f); - return {scale, offset}; + return {scale, zero_point}; } -int8_t quantize_i8_fp32(float value, float scale, int32_t offset) { +int8_t quantize_i8_fp32(float value, float scale, int32_t zero_point) { return std::clamp( - round_to_nearest_even_i32(value / scale) - offset, std::numeric_limits::min(), + round_to_nearest_even_i32(value / scale) - zero_point, std::numeric_limits::min(), std::numeric_limits::max()); } diff --git a/test/reference/quantize.hpp b/test/reference/quantize.hpp index 46c4bf8f..b073d460 100644 --- a/test/reference/quantize.hpp +++ b/test/reference/quantize.hpp @@ -16,18 +16,18 @@ namespace kai::test { /// @param[in] min_value Minimum value. /// @param[in] max_value Maximum value. /// -/// @return The scale and offset. -std::tuple get_qi8_scale_offset_from_range(float min_value, float max_value); +/// @return The scale and zero point. +std::tuple get_qi8_scale_zero_point_from_range(float min_value, float max_value); /// Quantizes the single-precision floating-point value using 8-bit asymmetric quantization. /// -/// Formula: `q = f / scale + offset` where `q` is quantized value and `f` is floating-point value. +/// Formula: `q = round(f / scale) - zero_point` where `q` is quantized value and `f` is floating-point value. /// /// @param[in] value Value to be quantized. /// @param[in] scale Scale. -/// @param[in] offset Offset. +/// @param[in] zero_point Zero point. /// /// @return The quantized value. -int8_t quantize_i8_fp32(float value, float scale, int32_t offset); +int8_t quantize_i8_fp32(float value, float scale, int32_t zero_point); } // namespace kai::test diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index 0091ef77..d8be126a 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -30,10 +30,10 @@ struct MatMulMethod { size_t m0; ///< Block size in M dimension.
size_t n0; ///< Block size in N dimension. - DataFormat dst_format; ///< Data format of the destination matrix. - DataFormat lhs_format; ///< Data format of the LHS matrix. - DataFormat p_lhs_format; ///< Data format of the preprocessed LHS matrix. - DataFormat rhs_format; ///< Data format of the RHS matrix. + DataFormat dst_format; ///< Data format of the destination matrix. + DataFormat lhs_format; ///< Data format of the LHS matrix. + DataFormat packed_lhs_format; ///< Data format of the packed LHS matrix. + DataFormat rhs_format; ///< Data format of the RHS matrix. /// Gets the offset in bytes of the LHS matrix. /// @@ -45,13 +45,13 @@ struct MatMulMethod { /// /// @param[in] m Size of the matrix in M dimension. /// @param[in] k Size of the matrix in K dimension. - std::function fn_get_preprocessed_lhs_size; + std::function fn_get_packed_lhs_size; /// Gets the offset in bytes of the packed LHS matrix. /// /// @param[in] m_idx Coordinate of the matrix in M dimension. /// @param[in] k Size of the matrix in K dimension. - std::function fn_get_preprocessed_lhs_offset; + std::function fn_get_packed_lhs_offset; /// Preprocesses the LHS matrix. /// @@ -59,9 +59,8 @@ struct MatMulMethod { /// @param[in] k Size of the matrix in K dimension. /// @param[in] lhs LHS matrix data buffer. /// @param[in] lhs_row_stride Row stride in bytes of the LHS matrix. - /// @param[in] preprocessed_lhs Preprocessed LHS matrix data buffer. - std::function - fn_preprocess_lhs; + /// @param[out] packed_lhs Packed LHS matrix data buffer. + std::function fn_preprocess_lhs; }; /// List of supported matrix multiplication methods. @@ -72,13 +71,13 @@ static const std::array matmul_methods = { .dst_format = DataFormat(DataType::FP32), .lhs_format = DataFormat(DataType::FP32), - .p_lhs_format = + .packed_lhs_format = DataFormat(DataType::QI8, 4, 8, DataFormat::QuantizationFormat::PER_ROW, DataType::FP32, DataType::I32), .rhs_format = DataFormat(DataType::UNKNOWN), // Unused. 
.fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qa8dxP4X8_f32, - .fn_get_preprocessed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP4X8_f32, - .fn_get_preprocessed_lhs_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qa8dxP4X8_f32, + .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP4X8_f32, + .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qa8dxP4X8_f32, .fn_preprocess_lhs = kai_run_lhs_quant_pack_qa8dxP4X8_f32, }, }; @@ -102,8 +101,8 @@ private: protected: /// Cached test data that is shared between multiple test case. struct TestData { - std::vector lhs; ///< LHS operand. - std::vector ref_p_lhs; ///< Reference packed LHS. + std::vector lhs; ///< LHS operand. + std::vector ref_packed_lhs; ///< Reference packed LHS. }; /// Gets the test data for the current test case. @@ -122,9 +121,9 @@ protected: const auto& method = matmul_methods.at(method_no); auto lhs = fill_matrix_random(info.m, info.k, method.lhs_format, 0); - auto ref_p_lhs = pack(method.p_lhs_format, lhs.data(), method.lhs_format, info.m, info.k); + auto ref_packed_lhs = pack(method.packed_lhs_format, lhs.data(), method.lhs_format, info.m, info.k); - return _data[data_id] = {std::move(lhs), std::move(ref_p_lhs)}; + return _data[data_id] = {std::move(lhs), std::move(ref_packed_lhs)}; } private: @@ -142,8 +141,8 @@ TEST_P(MatMulTest, PackedLhs) { } const auto [rect_start_row, rect_start_col, rect_height, rect_width] = portion.compute_portion( - info.m, info.k, method.p_lhs_format.scheduler_block_height(info.m), - method.p_lhs_format.scheduler_block_width(info.k)); + info.m, info.k, method.packed_lhs_format.scheduler_block_height(info.m), + method.packed_lhs_format.scheduler_block_width(info.k)); if (rect_height == 0 || rect_width == 0) { GTEST_SKIP(); @@ -151,25 +150,26 @@ TEST_P(MatMulTest, PackedLhs) { const auto ref_lhs_row_stride = method.lhs_format.default_row_stride(info.k); - const auto p_lhs_size = 
method.fn_get_preprocessed_lhs_size(info.m, info.k); - const auto ref_p_lhs_size = method.p_lhs_format.default_size_in_bytes(info.m, info.k); - ASSERT_EQ(p_lhs_size, ref_p_lhs_size); + const auto packed_lhs_size = method.fn_get_packed_lhs_size(info.m, info.k); + const auto ref_packed_lhs_size = method.packed_lhs_format.default_size_in_bytes(info.m, info.k); + ASSERT_EQ(packed_lhs_size, ref_packed_lhs_size); const auto lhs_offset = method.fn_get_lhs_offset(rect_start_row, ref_lhs_row_stride); const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(rect_start_row, 0, info.k); ASSERT_EQ(lhs_offset, ref_lhs_offset); - const auto p_lhs_offset = method.fn_get_preprocessed_lhs_offset(rect_start_row, info.k); - const auto ref_p_lhs_offset = method.p_lhs_format.default_offset_in_bytes(rect_start_row, 0, info.k); - ASSERT_EQ(p_lhs_offset, ref_p_lhs_offset); + const auto packed_lhs_offset = method.fn_get_packed_lhs_offset(rect_start_row, info.k); + const auto ref_packed_lhs_offset = method.packed_lhs_format.default_offset_in_bytes(rect_start_row, 0, info.k); + ASSERT_EQ(packed_lhs_offset, ref_packed_lhs_offset); - std::vector p_lhs; - p_lhs.resize(p_lhs_size); + std::vector packed_lhs; + packed_lhs.resize(packed_lhs_size); method.fn_preprocess_lhs( - rect_height, rect_width, data.lhs.data() + lhs_offset, ref_lhs_row_stride, p_lhs.data() + p_lhs_offset); + rect_height, rect_width, data.lhs.data() + lhs_offset, ref_lhs_row_stride, + packed_lhs.data() + packed_lhs_offset); const auto success = compare( - p_lhs.data(), data.ref_p_lhs.data(), method.p_lhs_format, info.m, info.k, portion, + packed_lhs.data(), data.ref_packed_lhs.data(), method.packed_lhs_format, info.m, info.k, portion, DefaultMismatchHandler(0, 0.0001, 0, 0.001)); ASSERT_TRUE(success); } -- GitLab