diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c index ec43f6054569755945b50221ae8bcd4fb0336005..9a7aaf11cf1cc18e83248a555fb6a6fc5dff657e 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c @@ -99,7 +99,7 @@ void kai_run_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot( const void* B_ptr = rhs_packed; void* output_ptr = dst; - uint64_t flags = 0; + uint64_t flags = 2; __asm__ __volatile__( ".inst 0xd503477f // SMSTART ZA\n" diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c index 9b9e3caf589687895edad88ebff0ebf5b25d4c99..627ab6884a57899716fe5a09e2189b9b10938175 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c @@ -91,7 +91,7 @@ void kai_run_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla( const void* B_ptr = rhs_packed; void* output_ptr = dst; - uint64_t flags = 0; + uint64_t flags = 2; __asm__ __volatile__( ".inst 0xd503477f // SMSTART ZA\n" diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c index 6eebd1e2baf71bbbaa4f876e1f8d4b65a4a8c80f..cf0608401658f78884c6bdb67d31ec7a57bd86cb 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c +++ 
b/kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c @@ -92,7 +92,7 @@ void kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla( const void* B_ptr = rhs_packed; void* output_ptr = dst; - uint64_t flags = 0; + uint64_t flags = 2; __asm__ __volatile__( ".inst 0xd503477f // SMSTART ZA\n" diff --git a/test/reference/clamp.cpp b/test/reference/clamp.cpp index a4ba773cc432ce016c554b3f670c30c52bf2b52f..6a8e7433c7069f14df8067c2eb104dde82608afa 100644 --- a/test/reference/clamp.cpp +++ b/test/reference/clamp.cpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com> +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> // // SPDX-License-Identifier: Apache-2.0 // @@ -11,11 +11,44 @@ #include <cstdint> #include <vector> +#include "kai/kai_common.h" +#include "test/common/float16.hpp" #include "test/common/memory.hpp" +#include "test/common/numeric_limits.hpp" #include "test/common/round.hpp" namespace kai::test { +template <typename T> +std::tuple<T, T> find_clamp_range(const void* src, size_t len, float ratio) { + KAI_ASSUME(ratio > 0.0F); + KAI_ASSUME(ratio <= 1.0F); + + T min_value = numeric_highest<T>; + T max_value = numeric_lowest<T>; + + for (size_t i = 0; i < len; ++i) { + const T value = read_array<T>(src, i); + + min_value = std::min(min_value, value); + max_value = std::max(max_value, value); + } + + min_value = std::max(min_value, numeric_lowest<T>); + max_value = std::min(max_value, numeric_highest<T>); + + const T range = max_value - min_value; + const T reduction = static_cast<T>(static_cast<float>(range) * (1.0F - ratio) / 2); + + const T clamp_min_value = min_value + reduction; + const T clamp_max_value = max_value - reduction; + + return {clamp_min_value, clamp_max_value}; +} + +template std::tuple<float, float> find_clamp_range<float>(const void* src, size_t len, float ratio); +template std::tuple<Float16, Float16> find_clamp_range<Float16>(const void* src, size_t len, float ratio); + template <typename T> std::vector<uint8_t> clamp(const void* src, size_t len, T
min_value, T max_value) { std::vector<uint8_t> dst(round_up_division(len * size_in_bits<T>, 8)); @@ -28,5 +61,6 @@ std::vector<uint8_t> clamp(const void* src, size_t len, T min_value, T max_value } template std::vector<uint8_t> clamp(const void* src, size_t len, float min_value, float max_value); +template std::vector<uint8_t> clamp(const void* src, size_t len, Float16 min_value, Float16 max_value); } // namespace kai::test diff --git a/test/reference/clamp.hpp b/test/reference/clamp.hpp index 24d3ac6c74c0acb02c4742ccdeb307e92f7f06bd..b665917e7f6fbc9d7b92c6a49142415f2703d5a2 100644 --- a/test/reference/clamp.hpp +++ b/test/reference/clamp.hpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com> +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> // // SPDX-License-Identifier: Apache-2.0 // @@ -8,10 +8,21 @@ #include <cstddef> #include <cstdint> +#include <tuple> #include <vector> namespace kai::test { +/// Finds the clamping parameters to limit the dynamic range. +/// +/// @param[in] src The data buffer. +/// @param[in] len The number of values. +/// @param[in] ratio The ratio between the output dynamic range and the input dynamic range. +/// +/// @return The minimum value and the maximum value. +template <typename T> +std::tuple<T, T> find_clamp_range(const void* src, size_t len, float ratio); + /// Clamps the matrix. /// /// @param[in] src Data buffer of the source matrix.
diff --git a/test/tests/matmul_clamp_f32_f32_f32p_test.cpp b/test/tests/matmul_clamp_f32_f32_f32p_test.cpp index 2f6fd131e1cfc6fae97cec5351d813722c84369d..be124f2212ed0d3e12d9bd4cf66f0e70027c141e 100644 --- a/test/tests/matmul_clamp_f32_f32_f32p_test.cpp +++ b/test/tests/matmul_clamp_f32_f32_f32p_test.cpp @@ -25,6 +25,7 @@ #include "test/common/data_type.hpp" #include "test/common/memory.hpp" #include "test/common/test_suite.hpp" +#include "test/reference/clamp.hpp" #include "test/reference/fill.hpp" #include "test/reference/matmul.hpp" @@ -89,10 +90,15 @@ TEST_P(MatMulTest_f32_f32_f32p, EndToEnd) // NOLINT(google-readability-avoid-un const auto ref_bias = fill_random<float>(n, seed + 2); // Runs the reference implementation - const auto ref_dst = matmul( + const auto ref_dst_no_clamp = matmul( ref_lhs.data(), nullptr, nullptr, DataType::FP32, ref_rhs.data(), nullptr, nullptr, DataType::FP32, ref_bias.data(), nullptr, nullptr, DataType::FP32, DataType::FP32, m, n, k, false, false); + // Clamps the reference output. + const auto clamp_ratio = 0.8F; + const auto [clamp_min, clamp_max] = find_clamp_range<float>(ref_dst_no_clamp.data(), m * n, clamp_ratio); + const auto ref_dst = clamp(ref_dst_no_clamp.data(), m * n, clamp_min, clamp_max); + // Run the RHS packing micro-kernel. const auto rhs_stride = n * sizeof(float); @@ -124,8 +130,8 @@ TEST_P(MatMulTest_f32_f32_f32p, EndToEnd) // NOLINT(google-readability-avoid-un std::vector<uint8_t> imp_dst(imp_dst_size); ukernel_variant.interface.run_matmul( - m, n, k, ref_lhs.data(), 1, imp_packed_rhs->data(), reinterpret_cast<float*>(imp_dst.data()), 1, 1, - std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max()); + m, n, k, ref_lhs.data(), 1, imp_packed_rhs->data(), reinterpret_cast<float*>(imp_dst.data()), 1, 1, clamp_min, + clamp_max); // Compare the output of the micro-kernels against the output of the reference implementation.
for (size_t y = 0; y < m; ++y) { diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index dd4b0dd8eef53e327dc58cc9106666c79d3a521d..c57b34890a16598e46c1fecb285597275016ef13 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -25,10 +25,12 @@ #include "test/common/cpu_info.hpp" #include "test/common/data_format.hpp" #include "test/common/data_type.hpp" +#include "test/common/float16.hpp" #include "test/common/matmul_test_common.hpp" #include "test/common/matrix_portion.hpp" #include "test/common/printer.hpp" #include "test/common/sme.hpp" +#include "test/reference/clamp.hpp" #include "test/reference/fill.hpp" #include "test/reference/pack.hpp" @@ -360,6 +362,8 @@ protected: std::vector<uint8_t> rhs_t{}; ///< Transposed RHS matrix. std::vector<uint8_t> ref_packed_rhs{}; ///< Reference packed RHS. std::vector<uint8_t> ref_dst{}; ///< Reference output. + float clamp_min{}; ///< Minimum output value. + float clamp_max{}; ///< Maximum output value. }; /// Gets the test data for the current test case.
@@ -427,6 +431,37 @@ protected: method.dst_format.data_type(), // info.m, info.n, info.k, false, false); + float clamp_min = 0.0F; + float clamp_max = 0.0F; + constexpr float clamp_ratio = 0.8F; + + switch (method.dst_format.data_type()) { + case DataType::FP32: { + const auto [min_value, max_value] = + find_clamp_range<float>(ref_dst.data(), info.m * info.n, clamp_ratio); + ref_dst = clamp(ref_dst.data(), info.m * info.n, min_value, max_value); + + clamp_min = min_value; + clamp_max = max_value; + + break; + } + + case DataType::FP16: { + const auto [min_value, max_value] = + find_clamp_range<Float16>(ref_dst.data(), info.m * info.n, clamp_ratio); + ref_dst = clamp(ref_dst.data(), info.m * info.n, min_value, max_value); + + clamp_min = static_cast<float>(min_value); + clamp_max = static_cast<float>(max_value); + + break; + } + + default: + KAI_ERROR("Unsupported data type!"); + } + const auto& data = _data[data_id] = { .lhs = std::move(lhs), .ref_packed_lhs = std::move(ref_packed_lhs), @@ -436,6 +471,8 @@ protected: .rhs_t = std::move(rhs_t), .ref_packed_rhs = std::move(packed_rhs), .ref_dst = std::move(ref_dst), + .clamp_min = clamp_min, + .clamp_max = clamp_max, }; return data; @@ -730,8 +767,7 @@ TEST_P(MatMulTest, Output) { method.main_kernel( rect.height(), rect.width(), info.k, lhs_data + lhs_offset, rhs_data + rhs_offset, bias_data + bias_offset, - dst.data() + dst_offset, lhs_stride, rhs_stride, dst_stride, -std::numeric_limits<float>::infinity(), - std::numeric_limits<float>::infinity()); + dst.data() + dst_offset, lhs_stride, rhs_stride, dst_stride, data.clamp_min, data.clamp_max); DefaultMismatchHandler handler(0, 0.1, 0, 0.05); const auto success = compare(dst.data(), data.ref_dst.data(), method.dst_format, info.m, info.n, rect, handler);