diff --git a/CMakeLists.txt b/CMakeLists.txt index b6d9662c5e502801f5bae4fe8df606a7f4c818c9..5266a3a5d52f4405f060374360c8afb9331f54e4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,6 +60,9 @@ if(KLEIDIAI_BUILD_TESTS) test/common/data_format.cpp test/common/printer.cpp test/common/int4.cpp + test/common/compare.cpp + test/common/matrix_portion.cpp + test/common/rect.cpp test/reference/round.cpp ) diff --git a/test/common/compare.cpp b/test/common/compare.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f1dfecdf12e7fd0aaca424d7c8a19345e688c942 --- /dev/null +++ b/test/common/compare.cpp @@ -0,0 +1,272 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/compare.hpp" + +#include +#include +#include +#include + +#include "src/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/common/int4.hpp" +#include "test/common/logging.hpp" +#include "test/common/memory.hpp" +#include "test/common/printer.hpp" +#include "test/common/rect.hpp" + +namespace kai::test { + +namespace { + +/// Calculates the absolute and relative errors. +/// +/// @param[in] imp Value under test. +/// @param[in] ref Reference value. +/// +/// @return The absolute error and relative error. +template +std::tuple calculate_error(T imp, T ref) { + const auto imp_f = static_cast(imp); + const auto ref_f = static_cast(ref); + + const auto abs_error = std::abs(imp_f - ref_f); + const auto rel_error = ref_f != 0 ? abs_error / std::abs(ref_f) : 0.0F; + + return {abs_error, rel_error}; +} + +/// Compares matrices with per-row quantization. +template +bool compare_raw( + const void* imp_data, const void* ref_data, size_t full_height, size_t full_width, const Rect& rect, + MismatchHandler& handler) { + for (size_t y = 0; y < full_height; ++y) { + for (size_t x = 0; x < full_width; ++x) { + const auto in_roi = + y >= rect.start_row() && y < rect.end_row() && x >= rect.start_col() && x < rect.end_col(); + + const auto imp_value = read_array(imp_data, y * full_width + x); + const auto ref_value = in_roi ? read_array(ref_data, y * full_width + x) : 0; + + const auto [abs_err, rel_err] = calculate_error(imp_value, ref_value); + + if (abs_err != 0 || rel_err != 0) { + const auto notifying = !in_roi || handler.handle_data(abs_err, rel_err); + + if (notifying) { + KAI_LOGE("Mismatched data at (", y, ", ", x, "): actual = ", imp_value, ", expected: ", ref_value); + } + } + } + } + + return handler.success(full_height * full_width); +} + +/// Compares matrices with per-row quantization. +template +bool compare_per_row( + const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, + const Rect& rect, MismatchHandler& handler) { + const auto block_height = format.block_height(); + const auto block_width = format.block_width(); + + KAI_ASSUME(format.scheduler_block_height(full_height) == block_height); + KAI_ASSUME(format.scheduler_block_width(full_width) == full_width); + KAI_ASSUME(rect.start_col() == 0); + KAI_ASSUME(rect.width() == full_width); + + const auto data_bits = size_in_bits; + + const auto num_groups = (full_height + block_height - 1) / block_height; + const auto group_num_blocks = (full_width + block_width - 1) / block_width; + + const auto group_offsets_bytes = block_height * sizeof(Offset); + const auto group_scales_bytes = block_height * sizeof(Scale); + const auto block_data_bytes = block_height * block_width * data_bits / 8; + + const auto begin_group = rect.start_row() / block_height; + const auto end_group = rect.end_row() / block_height; + + const auto* imp_ptr = reinterpret_cast(imp_data); + const auto* ref_ptr = reinterpret_cast(ref_data); + + for (size_t group_no = 0; group_no < num_groups; ++group_no) { + const auto in_roi = group_no >= begin_group && group_no < end_group; + + // Checks the quantization offsets. + for (size_t i = 0; i < block_height; ++i) { + const auto imp_offset = reinterpret_cast(imp_ptr)[i]; + const Offset ref_offset = in_roi ? reinterpret_cast(ref_ptr)[i] : 0; + const auto [abs_err, rel_err] = calculate_error(imp_offset, ref_offset); + + if (abs_err != 0 || rel_err != 0) { + handler.mark_as_failed(); + + const auto raw_row = group_no * block_height + i; + KAI_LOGE( + "Mismatched quantization offset ", raw_row, ": actual = ", imp_offset, ", expected: ", ref_offset); + } + } + + imp_ptr += group_offsets_bytes; + ref_ptr += group_offsets_bytes; + + // Checks the data. + for (size_t block_no = 0; block_no < group_num_blocks; ++block_no) { + for (size_t y = 0; y < block_height; ++y) { + for (size_t x = 0; x < block_width; ++x) { + const auto imp_data = read_array(imp_ptr, y * block_width + x); + const Data ref_data = in_roi ? read_array(ref_ptr, y * block_width + x) : Data(0); + const auto [abs_err, rel_err] = calculate_error(imp_data, ref_data); + + if (abs_err != 0 || rel_err != 0) { + const auto notifying = !in_roi || handler.handle_data(abs_err, rel_err); + + if (notifying) { + const auto raw_row = group_no * block_height + y; + const auto raw_col = block_no * block_width + x; + + KAI_LOGE( + "Mismatched data at (", raw_row, ", ", raw_col, "): actual = ", imp_data, + ", expected: ", ref_data); + } + } + } + } + + imp_ptr += block_data_bytes; + ref_ptr += block_data_bytes; + } + + // Checks the quantization scales. + for (size_t i = 0; i < block_height; ++i) { + const auto imp_scale = reinterpret_cast(imp_ptr)[i]; + const Scale ref_scale = in_roi ? reinterpret_cast(ref_ptr)[i] : 0; + const auto [abs_err, rel_err] = calculate_error(imp_scale, ref_scale); + + if (abs_err != 0 || rel_err != 0) { + handler.mark_as_failed(); + + const auto raw_row = group_no * block_height + i; + KAI_LOGE( + "Mismatched quantization scale ", raw_row, ": actual = ", imp_scale, ", expected: ", ref_scale); + } + } + + imp_ptr += group_scales_bytes; + ref_ptr += group_scales_bytes; + } + + return handler.success(rect.height() * full_width); +} + +} // namespace + +bool compare( + const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, + const Rect& rect, MismatchHandler& handler) { + const auto data_type = format.data_type(); + const auto scale_dt = format.scale_data_type(); + const auto offset_dt = format.zero_point_data_type(); + + switch (format.quantization_format()) { + case DataFormat::QuantizationFormat::NONE: + switch (data_type) { + case DataType::FP32: + return compare_raw(imp_data, ref_data, full_height, full_width, rect, handler); + + default: + break; + } + + break; + + case DataFormat::QuantizationFormat::PER_ROW: + if (data_type == DataType::QAI8 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { + return compare_per_row( + imp_data, ref_data, format, full_height, full_width, rect, handler); + } else if (data_type == DataType::QSI4 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { + return compare_per_row( + imp_data, ref_data, format, full_height, full_width, rect, handler); + } + + break; + + default: + break; + } + + KAI_ERROR("Unsupported format!"); +} + +// ===================================================================================================================== + +DefaultMismatchHandler::DefaultMismatchHandler( + float abs_error_threshold, float rel_error_threshold, size_t abs_mismatched_threshold, + float rel_mismatched_threshold) : + _abs_error_threshold(abs_error_threshold), + _rel_error_threshold(rel_error_threshold), + _abs_mismatched_threshold(abs_mismatched_threshold), + _rel_mismatched_threshold(rel_mismatched_threshold), + _num_mismatches(0), + _failed(false) { +} + +DefaultMismatchHandler::DefaultMismatchHandler(const DefaultMismatchHandler& rhs) : + _abs_error_threshold(rhs._abs_error_threshold), + _rel_error_threshold(rhs._rel_error_threshold), + _abs_mismatched_threshold(rhs._abs_mismatched_threshold), + _rel_mismatched_threshold(rhs._rel_mismatched_threshold), + _num_mismatches(0), + _failed(false) { + // Cannot copy mismatch handler that is already in use. + KAI_ASSUME(rhs._num_mismatches == 0); + KAI_ASSUME(!rhs._failed); +} + +DefaultMismatchHandler& DefaultMismatchHandler::operator=(const DefaultMismatchHandler& rhs) { + if (this != &rhs) { + // Cannot copy mismatch handler that is already in use. + KAI_ASSUME(rhs._num_mismatches == 0); + KAI_ASSUME(!rhs._failed); + + _abs_error_threshold = rhs._abs_error_threshold; + _rel_error_threshold = rhs._rel_error_threshold; + _abs_mismatched_threshold = rhs._abs_mismatched_threshold; + _rel_mismatched_threshold = rhs._rel_mismatched_threshold; + } + + return *this; +} + +bool DefaultMismatchHandler::handle_data(float absolute_error, float relative_error) { + const auto mismatched = absolute_error > _abs_error_threshold && relative_error > _rel_error_threshold; + + if (mismatched) { + ++_num_mismatches; + } + + return mismatched; +} + +void DefaultMismatchHandler::mark_as_failed() { + _failed = true; +} + +bool DefaultMismatchHandler::success(size_t num_checks) const { + if (_failed) { + return false; + } + + const auto mismatched_rate = static_cast(_num_mismatches) / static_cast(num_checks); + return _num_mismatches <= _abs_mismatched_threshold || mismatched_rate <= _rel_mismatched_threshold; +} + +} // namespace kai::test diff --git a/test/common/compare.hpp b/test/common/compare.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3bca0a5c30a9c28a78138723bb189567c344cfae --- /dev/null +++ b/test/common/compare.hpp @@ -0,0 +1,125 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace kai::test { + +class DataFormat; +class Rect; +class MismatchHandler; + +/// Compares two matrices to check whether they are matched. +/// +/// @param[in] imp_data Data buffer of the actual implementation matrix. +/// @param[in] ref_data Data buffer of the reference implementation matrix. +/// @param[in] format Data format. +/// @param[in] full_height Height of the full matrix. +/// @param[in] full_width Width of the full matrix. +/// @param[in] rect Rectangular region of the matrix that is populated with data. +/// @param[in] handler Mismatch handler. +/// +/// @return `true` if the two matrices are considered matched. +bool compare( + const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, + const Rect& rect, MismatchHandler& handler); + +/// Handles mismatches found by @ref validate function. +class MismatchHandler { +public: + /// Constructor. + MismatchHandler() = default; + + /// Destructor. + virtual ~MismatchHandler() = default; + + /// Copy constructor. + MismatchHandler(const MismatchHandler&) = default; + + /// Copy assignment. + MismatchHandler& operator=(const MismatchHandler&) = default; + + /// Move constructor. + MismatchHandler(MismatchHandler&&) noexcept = default; + + /// Move assignment. + MismatchHandler& operator=(MismatchHandler&&) noexcept = default; + + /// Handles new mismatch result. + /// + /// This method must be called even when no error is detected. + /// + /// @param[in] absolute_error Absolute error. + /// @param[in] relative_error Relative error. + /// + /// @return `true` if the mismatch is sufficiently large to be logged as real mismatch. + virtual bool handle_data(float absolute_error, float relative_error) = 0; + + /// Marks the result as failed. + /// + /// It is zero tolerance if the data point is considered impossible to have mismatch + /// regardless of implementation method. + /// These normally include data point outside if the portion of interest (these must be 0) + /// and data point belongs to quantization information. + virtual void mark_as_failed() = 0; + + /// Returns a value indicating whether the two matrices are considered matched. + /// + /// @param[in] num_checks Total number of data points that have been checked. + /// + /// @return `true` if the two matrices are considered matched. + [[nodiscard]] virtual bool success(size_t num_checks) const = 0; +}; + +/// This mismatch handler considers two values being mismatched when both the relative error +/// and the absolute error exceed their respective thresholds. +/// +/// This mismatch handler considers two matrices being mismatched when the number of mismatches +/// exceed both the relative and absolute thresholds. +class DefaultMismatchHandler final : public MismatchHandler { +public: + /// Creates a new mismatch handler. + /// + /// @param[in] abs_error_threshold Threshold for absolute error + /// @param[in] rel_error_threshold Threshold for relative error. + /// @param[in] abs_mismatched_threshold Threshold for the number of mismatched data points. + /// @param[in] rel_mismatched_threshold Threshold for the ratio of mismatched data points. + DefaultMismatchHandler( + float abs_error_threshold, float rel_error_threshold, size_t abs_mismatched_threshold, + float rel_mismatched_threshold); + + /// Destructur. + ~DefaultMismatchHandler() = default; + + /// Copy constructor. + DefaultMismatchHandler(const DefaultMismatchHandler& rhs); + + /// Copy assignment. + DefaultMismatchHandler& operator=(const DefaultMismatchHandler& rhs); + + /// Move constructor. + DefaultMismatchHandler(DefaultMismatchHandler&& rhs) noexcept = default; + + /// Move assignment. + DefaultMismatchHandler& operator=(DefaultMismatchHandler&& rhs) noexcept = default; + + bool handle_data(float absolute_error, float relative_error) override; + void mark_as_failed() override; + [[nodiscard]] bool success(size_t num_checks) const override; + +private: + float _abs_error_threshold; + float _rel_error_threshold; + size_t _abs_mismatched_threshold; + float _rel_mismatched_threshold; + + size_t _num_mismatches; + bool _failed; +}; + +} // namespace kai::test diff --git a/test/common/matrix_portion.cpp b/test/common/matrix_portion.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e7c42bf130f3616707d790e9feae035582d73fce --- /dev/null +++ b/test/common/matrix_portion.cpp @@ -0,0 +1,65 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/matrix_portion.hpp" + +#include +#include + +#include "src/kai_common.h" +#include "test/common/rect.hpp" +#include "test/reference/round.hpp" + +namespace kai::test { + +MatrixPortion::MatrixPortion(float start_row, float start_col, float height, float width) : + _start_row(start_row), _start_col(start_col), _height(height), _width(width) { +} + +float MatrixPortion::start_row() const { + return _start_row; +} + +float MatrixPortion::start_col() const { + return _start_col; +} + +float MatrixPortion::height() const { + return _height; +} + +float MatrixPortion::width() const { + return _width; +} + +Rect MatrixPortion::compute_portion( + size_t full_height, size_t full_width, size_t scheduler_block_height, size_t scheduler_block_width) const { + KAI_ASSUME(_start_row >= 0.0F && _start_row <= 1.0F); + KAI_ASSUME(_start_col >= 0.0F && _start_col <= 1.0F); + KAI_ASSUME(_height >= 0.0F && _height <= 1.0F); + KAI_ASSUME(_width >= 0.0F && _width <= 1.0F); + + auto start_row = round_to_nearest_even_usize(_start_row * static_cast(full_height)); + auto start_col = round_to_nearest_even_usize(_start_col * static_cast(full_width)); + auto height = round_to_nearest_even_usize(_height * static_cast(full_height)); + auto width = round_to_nearest_even_usize(_width * static_cast(full_width)); + + start_row = round_down_multiple(start_row, scheduler_block_height); + start_col = round_down_multiple(start_col, scheduler_block_width); + + start_row = std::min(start_row, round_down_multiple(full_height, scheduler_block_height)); + start_col = std::min(start_col, round_down_multiple(full_width, scheduler_block_width)); + + height = round_up_multiple(height, scheduler_block_height); + width = round_up_multiple(width, scheduler_block_width); + + height = std::min(height, full_height - start_row); + width = std::min(width, full_width - start_col); + + return {start_row, start_col, height, width}; +} + +} // namespace kai::test diff --git a/test/common/matrix_portion.hpp b/test/common/matrix_portion.hpp new file mode 100644 index 0000000000000000000000000000000000000000..68193a402f2d093062b9eac3952b0e7fd02d7715 --- /dev/null +++ b/test/common/matrix_portion.hpp @@ -0,0 +1,68 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "test/common/rect.hpp" + +namespace kai::test { + +/// Portion of a matrix. +/// +/// This class is used to define the sub-matrix under test. +/// +/// This is the relative version of @ref Rect. +class MatrixPortion { +public: + /// Creates a new matrix portion. + /// + /// @param[in] start_row Starting row as the ratio to the height of the matrix. + /// @param[in] start_col Starting column as the ratio to the width of the matrix. + /// @param[in] height Portion height as the ratio to the height of the matrix. + /// @param[in] width Portion width as the ratio to the width of the matrix. + MatrixPortion(float start_row, float start_col, float height, float width); + + /// Gets the starting row as the ratio to the height of the matrix. + [[nodiscard]] float start_row() const; + + /// Gets the starting column as the ratio to the width of the matrix. + [[nodiscard]] float start_col() const; + + /// Gets the portion height as the ratio to the height of the matrix. + [[nodiscard]] float height() const; + + /// Gets the portion width as the ratio to the width of the matrix. + [[nodiscard]] float width() const; + + /// Computes the starting coordinate and the shape of the sub-matrix. + /// + /// Requirements: + /// + /// * The starting coordinate of the sub-matrix shall be aligned with the scheduling block boundary. + /// * If it is not the scheduling block at the right and/or bottom edge of the full matrix, the height and width + /// of the sub-matrix shall be rounded up to multiple of the scheduling block height and width. + /// * If it is the scheduling block at the right and/or bottom edge of the full matrix, the height and width + /// of the sub-matrix shall be the rounded up to the edge of the matrix. + /// + /// @param[in] full_height Matrix height. + /// @param[in] full_width Matrix width. + /// @param[in] scheduler_block_height Block height for scheduling purpose. + /// @param[in] scheduler_block_width Block width for scheduling purpose. + /// + /// @return The rectangular region of the matrix. + [[nodiscard]] Rect compute_portion( + size_t full_height, size_t full_width, size_t scheduler_block_height, size_t scheduler_block_width) const; + +private: + float _start_row; + float _start_col; + float _height; + float _width; +}; + +} // namespace kai::test diff --git a/test/common/rect.cpp b/test/common/rect.cpp new file mode 100644 index 0000000000000000000000000000000000000000..457ac2f3573021a6ff77076071578bbca1f1ec4e --- /dev/null +++ b/test/common/rect.cpp @@ -0,0 +1,41 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/rect.hpp" + +#include + +namespace kai::test { + +Rect::Rect(size_t start_row, size_t start_col, size_t height, size_t width) : + _start_row(start_row), _start_col(start_col), _height(height), _width(width) { +} + +size_t Rect::start_row() const { + return _start_row; +} + +size_t Rect::start_col() const { + return _start_col; +} + +size_t Rect::height() const { + return _height; +} + +size_t Rect::width() const { + return _width; +} + +size_t Rect::end_row() const { + return _start_row + _height; +} + +size_t Rect::end_col() const { + return _start_col + _width; +} + +} // namespace kai::test diff --git a/test/common/rect.hpp b/test/common/rect.hpp new file mode 100644 index 0000000000000000000000000000000000000000..92b033b486bdd44d7cb76680f46438c10274e4c5 --- /dev/null +++ b/test/common/rect.hpp @@ -0,0 +1,51 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace kai::test { + +/// Rectangular region of a matrix. +/// +/// This is the absolute version of @ref MatrixPortion. +class Rect { +public: + /// Creates a new rectangular region of a matrix. + /// + /// @param[in] start_row Starting row index. + /// @param[in] start_col Starting column index. + /// @param[in] height Number of rows. + /// @param[in] width Number of columns. + Rect(size_t start_row, size_t start_col, size_t height, size_t width); + + /// Gets the starting row index. + [[nodiscard]] size_t start_row() const; + + /// Gets the starting column index. + [[nodiscard]] size_t start_col() const; + + /// Gets the number of rows. + [[nodiscard]] size_t height() const; + + /// Gets the number of columns. + [[nodiscard]] size_t width() const; + + /// Gets the end (exclusive) row index. + [[nodiscard]] size_t end_row() const; + + /// Gets the end (exclusive) column index. + [[nodiscard]] size_t end_col() const; + +private: + size_t _start_row; + size_t _start_col; + size_t _height; + size_t _width; +}; + +} // namespace kai::test