From 3d43a39b20a3b0afbb1e03ab3d521d4d17b2dc99 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 15 Aug 2024 11:03:21 +0100 Subject: [PATCH 1/3] Add tests for Int4 kernels * Tested kernels: - lhs_quant_pack_qai8dxp_f32 - rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 - matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm * Only entire matrix is tested - support for testing arbitrary portion of the output will come later. * Update test framework to accommodate the new tests. - The new test no longer relies on DataFormat concept. Now the packing function for each format is explicitly called. Signed-off-by: Viet-Hoa Do --- CMakeLists.txt | 1 + test/common/int4.hpp | 4 +- test/common/round.cpp | 14 + test/common/round.hpp | 27 ++ test/common/type_traits.hpp | 10 +- test/reference/cast.cpp | 11 + test/reference/cast.hpp | 9 + test/reference/fill.cpp | 7 + test/reference/fill.hpp | 11 + test/reference/matmul.cpp | 90 ++++- test/reference/matmul.hpp | 10 + test/reference/pack.cpp | 336 +++++++++--------- test/reference/pack.hpp | 112 ++++++ test/reference/quantize.cpp | 190 +++++++--- test/reference/quantize.hpp | 180 ++++++++-- .../matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp | 96 +++++ 16 files changed, 841 insertions(+), 267 deletions(-) create mode 100644 test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index bf177358..e51d0e63 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -187,6 +187,7 @@ if(KLEIDIAI_BUILD_TESTS) add_executable(kleidiai_test test/tests/matmul_test.cpp + test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp ) target_link_libraries(kleidiai_test diff --git a/test/common/int4.hpp b/test/common/int4.hpp index 7f0e93a0..0ebde82f 100644 --- a/test/common/int4.hpp +++ b/test/common/int4.hpp @@ -79,10 +79,10 @@ public: Int4& operator=(int value); /// Conversion operator. - explicit operator int32_t() const; + operator int32_t() const; /// Conversion operator. - explicit operator float() const; + operator float() const; /// Addition operator. [[nodiscard]] Int4 operator+(Int4 rhs) const; diff --git a/test/common/round.cpp b/test/common/round.cpp index bcb1cb37..08f6c7ed 100644 --- a/test/common/round.cpp +++ b/test/common/round.cpp @@ -25,10 +25,24 @@ size_t round_to_nearest_even_usize(float value) { return rounded; } +template <> +int32_t round_to_nearest_even(float value) { + return round_to_nearest_even_i32(value); +} + +template <> +size_t round_to_nearest_even(float value) { + return round_to_nearest_even_usize(value); +} + size_t round_up_multiple(size_t a, size_t b) { return ((a + b - 1) / b) * b; } +size_t round_up_division(size_t a, size_t b) { + return (a + b - 1) / b; +} + size_t round_down_multiple(size_t a, size_t b) { return (a / b) * b; } diff --git a/test/common/round.hpp b/test/common/round.hpp index 7503936f..7eb792cb 100644 --- a/test/common/round.hpp +++ b/test/common/round.hpp @@ -43,6 +43,25 @@ int32_t round_to_nearest_even_i32(float value); /// @return The rounded value. size_t round_to_nearest_even_usize(float value); +/// Rounds the specified value to nearest with tie to even. +/// +/// For example: +/// +/// * 0.4 is rounded to 0. +/// * 0.5 is rounded to 0 (as 0 is the nearest even value). +/// * 0.6 is rounded to 1. +/// * 1.4 is rounded to 1. +/// * 1.5 is rounded to 2 (as 2 is the nearest even value). +/// * 1.6 is rounded to 2. +/// +/// @tparam T The target data type (must be integer). +/// +/// @param[in] value Value to be rounded. +/// +/// @return The rounded value. 
+template +T round_to_nearest_even(float value); + /// Rounds up the input value to the multiple of the unit value. /// /// @param[in] a Input value. @@ -51,6 +70,14 @@ size_t round_to_nearest_even_usize(float value); /// @return The rounded value. size_t round_up_multiple(size_t a, size_t b); +/// Divides and rounds up. +/// +/// @param[in] a The dividend. +/// @param[in] b The divisor. +/// +/// @return The result of dividing a by b, rounded up. +size_t round_up_division(size_t a, size_t b); + /// Rounds down the input value to the multiple of the unit value. /// /// @param[in] a Input value. diff --git a/test/common/type_traits.hpp b/test/common/type_traits.hpp index 00559fc0..84601fba 100644 --- a/test/common/type_traits.hpp +++ b/test/common/type_traits.hpp @@ -9,6 +9,8 @@ #include #include +#include "test/common/float16.hpp" + namespace kai::test { class UInt4; @@ -25,7 +27,7 @@ inline constexpr bool is_unsigned = true; /// `true` if `T` is unsigned numeric type. template <> -inline constexpr bool is_unsigned = true; +inline constexpr bool is_unsigned = false; /// `true` if `T` is unsigned numeric type. template <> @@ -41,7 +43,7 @@ inline constexpr bool is_signed = false; /// `true` if `T` is signed numeric type. template <> -inline constexpr bool is_signed = false; +inline constexpr bool is_signed = true; /// `true` if `T` is signed numeric type. template <> @@ -67,6 +69,10 @@ inline constexpr bool is_integral = false; template inline constexpr bool is_floating_point = std::is_floating_point_v; +/// `true` if `T` is floating-point type. +template <> +inline constexpr bool is_floating_point = true; + /// `true` if `T` is floating-point type. template <> inline constexpr bool is_floating_point = true; diff --git a/test/reference/cast.cpp b/test/reference/cast.cpp index 21486c05..2a3393c5 100644 --- a/test/reference/cast.cpp +++ b/test/reference/cast.cpp @@ -12,6 +12,7 @@ #include "test/common/bfloat16.hpp" #include "test/common/data_type.hpp" #include "test/common/memory.hpp" +#include "test/common/round.hpp" namespace kai::test { @@ -41,4 +42,14 @@ std::vector cast(const void* src, kai::test::DataType src_dt, DataType KAI_ERROR("Unsupported cast data type!"); } +std::vector cast_qsu4_qsi4(const void* src, size_t length) { + std::vector dst(round_up_division(length, 2)); + + for (size_t i = 0; i < length; ++i) { + write_array(dst.data(), i, static_cast(static_cast(read_array(src, i)) + 8)); + } + + return dst; +} + } // namespace kai::test diff --git a/test/reference/cast.hpp b/test/reference/cast.hpp index 1ef3b9db..deb0bce9 100644 --- a/test/reference/cast.hpp +++ b/test/reference/cast.hpp @@ -25,4 +25,13 @@ namespace kai::test { /// @return The result matrix containing data in the destination data type. std::vector cast(const void* src, DataType src_dt, DataType dst_dt, size_t height, size_t width); +/// Converts each element of the source data from 4-bit signed symmetric quantized +/// to 4-bit unsigned symmetric quantized. +/// +/// @param[in] src The source data. +/// @param[in] length The number of elements. +/// +/// @return A new data buffer with converted values.
+std::vector cast_qsu4_qsi4(const void* src, size_t length); + } // namespace kai::test diff --git a/test/reference/fill.cpp b/test/reference/fill.cpp index d58ab6ce..49b16987 100644 --- a/test/reference/fill.cpp +++ b/test/reference/fill.cpp @@ -118,4 +118,11 @@ std::vector fill_matrix_random(size_t height, size_t width, const DataF } } +template +std::vector fill_random(size_t length, uint64_t seed) { + return fill_matrix_random_raw(1, length, seed); +} + +template std::vector fill_random(size_t length, uint64_t seed); + } // namespace kai::test diff --git a/test/reference/fill.hpp b/test/reference/fill.hpp index 09138145..80093952 100644 --- a/test/reference/fill.hpp +++ b/test/reference/fill.hpp @@ -24,4 +24,15 @@ class DataFormat; /// @return The data buffer for the matrix. std::vector fill_matrix_random(size_t height, size_t width, const DataFormat& format, uint64_t seed); +/// Creates a new data buffer filled with random data. +/// +/// @tparam Value The data type. +/// +/// @param[in] length The number of elements. +/// @param[in] seed The random seed. +/// +/// @return The data buffer. +template +std::vector fill_random(size_t length, uint64_t seed); + } // namespace kai::test diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp index 984982a2..ea764ca5 100644 --- a/test/reference/matmul.cpp +++ b/test/reference/matmul.cpp @@ -16,6 +16,7 @@ #include "test/common/float16.hpp" #include "test/common/int4.hpp" #include "test/common/memory.hpp" +#include "test/common/round.hpp" #include "test/reference/binary_elementwise.hpp" #include "test/reference/cast.hpp" #include "test/reference/pack.hpp" @@ -125,8 +126,10 @@ std::vector matmul_pack_rhs( } std::vector matmul( - const void* lhs, const void* lhs_scales, const void* lhs_zero_points, DataType lhs_dt, // - const void* rhs, const void* rhs_scales, const void* rhs_zero_points, DataType rhs_dt, // + const void* lhs, [[maybe_unused]] const void* lhs_scales, [[maybe_unused]] const void* lhs_zero_points, + DataType lhs_dt, // + const void* rhs, [[maybe_unused]] const void* rhs_scales, [[maybe_unused]] const void* rhs_zero_points, + DataType rhs_dt, // const void* bias, const void* bias_scales, const void* bias_zero_points, DataType bias_dt, // DataType dst_dt, // size_t m, size_t n, size_t k, // @@ -142,18 +145,6 @@ std::vector matmul( std::vector tmp_dst; std::vector tmp_bias; - if (data_type_is_quantized(lhs_dt)) { - tmp_lhs = dequantize( - lhs, lhs_scales, lhs_zero_points, lhs_dt, DataType::FP32, QuantizationMethod::PER_MATRIX, lhs_h, lhs_w); - lhs = tmp_lhs.data(); - } - - if (data_type_is_quantized(rhs_dt)) { - tmp_rhs = dequantize( - rhs, rhs_scales, rhs_zero_points, rhs_dt, DataType::FP32, QuantizationMethod::PER_ROW, rhs_h, rhs_w); - rhs = tmp_rhs.data(); - } - if (lhs_dt != dst_dt) { tmp_lhs = cast(lhs, lhs_dt, dst_dt, lhs_h, lhs_w); lhs = tmp_lhs.data(); @@ -193,4 +184,75 @@ std::vector matmul( return tmp_dst; } +template < + typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, + typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData> +std::vector matmul_clamp_nt_t( + size_t m, size_t n, size_t k, // + const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // + const void* biases, // + DstData min_value, DstData max_value) { + const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width); + const auto 
rhs_num_quant_per_row = round_up_division(k, rhs_quant_width); + + std::vector dst(m * n * sizeof(DstData)); + + const auto* lhs_scales_ptr = reinterpret_cast(lhs_scales); + const auto* rhs_scales_ptr = reinterpret_cast(rhs_scales); + const auto* lhs_zero_points_ptr = reinterpret_cast(lhs_zero_points); + const auto* rhs_zero_points_ptr = reinterpret_cast(rhs_zero_points); + const auto* biases_ptr = reinterpret_cast(biases); + auto* dst_ptr = reinterpret_cast(dst.data()); + + for (size_t y = 0; y < m; ++y) { + for (size_t x = 0; x < n; ++x) { + DstData acc = 0; + + for (size_t i = 0; i < k; ++i) { + const auto lhs_value = read_array(lhs_data, y * k + i); + const auto lhs_scale = lhs_scales_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width]; + const auto lhs_zero_point = lhs_zero_points_ptr != nullptr + ? lhs_zero_points_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width] + : 0; + + const auto rhs_value = read_array(rhs_data, x * k + i); + const auto rhs_scale = rhs_scales_ptr[x * rhs_num_quant_per_row + i / rhs_quant_width]; + const auto rhs_zero_point = rhs_zero_points_ptr != nullptr + ? rhs_zero_points_ptr[x * rhs_num_quant_per_row + i / rhs_quant_width] + : 0; + + acc += static_cast( + (static_cast(lhs_value) + static_cast(lhs_zero_point)) * + (static_cast(rhs_value) + static_cast(rhs_zero_point))) * + static_cast(lhs_scale) * static_cast(rhs_scale); + } + + if (biases_ptr != nullptr) { + acc += static_cast(biases_ptr[x]); + } + + acc = std::clamp(acc, min_value, max_value); + dst_ptr[y * n + x] = acc; + } + } + + return dst; +} + +template std::vector matmul_clamp_nt_t( + size_t m, size_t n, size_t k, // + const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // + const void* biases, // + float min_value, float max_value); + +template std::vector +matmul_clamp_nt_t( + size_t m, size_t n, size_t k, // + const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // + const void* biases, // + float min_value, float max_value); + } // namespace kai::test diff --git a/test/reference/matmul.hpp b/test/reference/matmul.hpp index e71f17c2..68b5de90 100644 --- a/test/reference/matmul.hpp +++ b/test/reference/matmul.hpp @@ -65,4 +65,14 @@ std::vector matmul( size_t m, size_t n, size_t k, // bool lhs_transposed, bool rhs_transposed); +template < + typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, + typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData> +std::vector matmul_clamp_nt_t( + size_t m, size_t n, size_t k, // + const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // + const void* biases, // + DstData min_value, DstData max_value); + } // namespace kai::test diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp index 72115779..0ff3904a 100644 --- a/test/reference/pack.cpp +++ b/test/reference/pack.cpp @@ -10,13 +10,14 @@ #include #include #include -#include #include #include #include "kai/kai_common.h" #include "test/common/data_format.hpp" #include "test/common/data_type.hpp" +#include "test/common/float16.hpp" +#include "test/common/memory.hpp" #include "test/common/round.hpp"
#include "test/reference/quantize.hpp" @@ -114,213 +115,224 @@ std::vector pack_bias_per_row( return dst; } -/// Packs the matrix from raw to quantized format. -template -std::vector pack_quant_per_row( - const void* src, size_t height, size_t width, size_t block_height, size_t block_width) { - const auto num_groups = (height + block_height - 1) / block_height; - const auto group_num_blocks = (width + block_width - 1) / block_width; +} // namespace - const auto group_zero_points_bytes = block_height * sizeof(ZeroPoint); - const auto group_scales_bytes = block_height * sizeof(Scale); - const auto block_data_bytes = block_height * block_width * sizeof(Output); - const auto group_bytes = group_zero_points_bytes + group_num_blocks * block_data_bytes + group_scales_bytes; - const auto dst_bytes = num_groups * group_bytes; +std::vector pack( + const DataFormat& dst_format, const void* src, [[maybe_unused]] const void* scales, const void* zero_points, + const DataFormat& src_format, size_t height, size_t width) { + const auto dst_dt = dst_format.data_type(); + const auto dst_qf = dst_format.pack_format(); + const auto src_dt = src_format.data_type(); + const auto src_qf = src_format.pack_format(); - std::vector dst; - dst.resize(dst_bytes); + const auto block_height = dst_format.actual_block_height(height); + const auto block_width = dst_format.actual_block_width(width); + const auto subblock_height = dst_format.actual_subblock_height(height); + const auto subblock_width = dst_format.actual_subblock_width(width); - const auto* src_ptr = reinterpret_cast(src); - auto* dst_ptr = dst.data(); + if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::BIAS_PER_ROW) { + KAI_ASSUME(src_dt == dst_dt); + + const auto data_esize = data_type_size_in_bits(dst_dt); + const auto zero_point_esize = data_type_size_in_bits(dst_format.zero_point_data_type()); + + if (data_esize % 8 == 0 && zero_point_esize % 8 == 0) { + return pack_bias_per_row( + data_esize / 8, zero_point_esize / 8, src, zero_points, height, width, block_height, block_width, + subblock_height, subblock_width); + } + } - std::vector scales; - scales.resize(block_height); + if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::NONE) { + KAI_ASSUME(src_dt == dst_dt); - std::vector zero_points; - zero_points.resize(block_height); + const auto data_esize = data_type_size_in_bits(dst_dt); - for (size_t group_no = 0; group_no < num_groups; ++group_no) { - // Finds the range of values and calculates the quantization information. 
- for (size_t y = 0; y < block_height; ++y) { - auto min_value = std::numeric_limits::max(); - auto max_value = std::numeric_limits::lowest(); + if (data_esize % 8 == 0) { + return pack_block( + src, data_esize / 8, height, width, block_height, block_width, subblock_height, subblock_width); + } + } - for (size_t x = 0; x < width; ++x) { - const auto value = src_ptr[(group_no * block_height + y) * width + x]; + KAI_ERROR("Unsupported operation!"); +} - if (value < min_value) { - min_value = value; - } +template +std::vector pack_data_scales( + const void* data, const void* scales, size_t height, size_t width, size_t quant_width) { + KAI_ASSUME_IF(size_in_bits < 8, quant_width % (8 / size_in_bits) == 0); + KAI_ASSUME_IF(size_in_bits < 8, width % (8 / size_in_bits) == 0); - if (value > max_value) { - max_value = value; - } - } + const auto num_quant_packets_x = round_up_multiple(width, quant_width) / quant_width; - std::tie(scales[y], zero_points[y]) = get_qai8_scale_zero_point_from_range(min_value, max_value); - } + const auto data_bytes = height * width * size_in_bits / 8; + const auto scales_bytes = height * num_quant_packets_x * sizeof(Scale); - // Packs the zero points. - memcpy(dst_ptr, zero_points.data(), group_zero_points_bytes); - dst_ptr += group_zero_points_bytes; - - // Quantizes and packs the data. - for (size_t x_block = 0; x_block < group_num_blocks; ++x_block) { - for (size_t block_y = 0; block_y < block_height; ++block_y) { - for (size_t block_x = 0; block_x < block_width; ++block_x) { - const auto value = - src_ptr[(group_no * block_height + block_y) * width + x_block * block_width + block_x]; - const auto qvalue = quantize_i8_fp32(value, scales[block_y], zero_points[block_y]); - *reinterpret_cast(dst_ptr) = qvalue; - ++dst_ptr; - } + std::vector dst(data_bytes + scales_bytes); + + const auto* scales_ptr = reinterpret_cast(scales); + auto* dst_ptr = dst.data(); + + for (size_t y = 0; y < height; ++y) { + for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { + write_array(dst_ptr, 0, *scales_ptr); + dst_ptr += sizeof(Scale); + ++scales_ptr; + + const auto len = std::min(x_quant + quant_width, width) - x_quant; + + for (size_t x_element = 0; x_element < len; ++x_element) { + const auto x = x_quant + x_element; + write_array(dst_ptr, x_element, read_array(data, y * width + x)); } - } - // Packs the scales. - memcpy(dst_ptr, scales.data(), group_scales_bytes); - dst_ptr += group_scales_bytes; + dst_ptr += len * size_in_bits / 8; + } } - KAI_ASSERT(reinterpret_cast(dst_ptr) - reinterpret_cast(dst.data()) == dst_bytes); + KAI_ASSERT(dst_ptr == &*dst.end()); return dst; } -/// Packs the matrix with per-row quantized format. -/// -/// The source matrix is per-row quantized with separate quantization scale and zero-points data buffer. -/// The destination data is per-row quantized with blocking and embedded quantization information. -std::vector pack_per_row_qs4( - const void* src, const void* scales, const void* zero_points, size_t height, size_t width, size_t block_height, - size_t block_width, size_t subblock_height, size_t subblock_width) { - // Number of elements in a sub-block in vertical and horizontal axes. - const auto num_element_rows = subblock_height; - const auto num_element_cols = subblock_width; - const auto src_element_row_stride = width / 2; - - // Number of sub-blocks in a block in vertical and horizontal axes. 
- const auto num_subblock_rows = block_height / subblock_height; - const auto num_subblock_cols = block_width / subblock_width; - const auto src_subblock_col_stride = subblock_width / 4; - const auto src_subblock_row_stride = subblock_height * width / 2; - - // Number of blocks in the matrix in vertical and horizontal axes. - const auto num_block_rows = (height + block_height - 1) / block_height; - const auto num_block_cols = (width + block_width - 1) / block_width; - const auto src_block_col_stride = block_width / 2; - const auto src_block_row_stride = block_height * width / 2; - - const auto dst_block_row_scales_bytes = block_height * sizeof(float); - const auto dst_block_row_zero_points_bytes = block_height * sizeof(int32_t); - const auto dst_block_row_data_bytes = num_block_cols * block_height * block_width / 2; - const auto dst_bytes = - num_block_rows * (dst_block_row_zero_points_bytes + dst_block_row_data_bytes + dst_block_row_scales_bytes); +template +std::vector pack_data_scales_interleave_block( + const void* data, const void* scales, size_t height, size_t width, size_t quant_width) { + KAI_ASSUME_IF(size_in_bits < 8, quant_width % (8 / size_in_bits) == 0); + KAI_ASSUME_IF(size_in_bits < 8, width % (8 / size_in_bits) == 0); + KAI_ASSUME(width % quant_width == 0); + KAI_ASSUME(quant_width % 2 == 0); - std::vector dst; - dst.resize(dst_bytes); + const auto num_quant_packets_x = round_up_multiple(width, quant_width) / quant_width; - const auto* src_ptr = reinterpret_cast(src); - const auto* scales_ptr = reinterpret_cast(scales); - const auto* zero_points_ptr = reinterpret_cast(zero_points); + const auto data_bytes = height * width * size_in_bits / 8; + const auto scales_bytes = height * num_quant_packets_x * sizeof(Scale); + + std::vector dst(data_bytes + scales_bytes); + + const auto* scales_ptr = reinterpret_cast(scales); auto* dst_ptr = dst.data(); - for (size_t block_row = 0; block_row < num_block_rows; ++block_row) { - if (zero_points_ptr != nullptr) { - memcpy(dst_ptr, zero_points_ptr + block_row * block_height, dst_block_row_zero_points_bytes); - } + for (size_t y = 0; y < height; ++y) { + for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { + write_array(dst_ptr, 0, *scales_ptr); + dst_ptr += sizeof(Scale); + ++scales_ptr; - dst_ptr += dst_block_row_zero_points_bytes; - - for (size_t block_col = 0; block_col < num_block_cols; ++block_col) { - for (size_t subblock_col = 0; subblock_col < num_subblock_cols; ++subblock_col) { - for (size_t subblock_row = 0; subblock_row < num_subblock_rows; ++subblock_row) { - for (size_t element_col = 0; element_col < num_element_cols / 4; ++element_col) { - for (size_t element_row = 0; element_row < num_element_rows; ++element_row) { - const auto byte_lo = src_ptr[ // - block_row * src_block_row_stride + block_col * src_block_col_stride + - subblock_row * src_subblock_row_stride + subblock_col * src_subblock_col_stride + - element_row * src_element_row_stride + element_col]; - const auto byte_hi = src_ptr[ // - block_row * src_block_row_stride + block_col * src_block_col_stride + - subblock_row * src_subblock_row_stride + subblock_col * src_subblock_col_stride + - element_row * src_element_row_stride + element_col + block_width / 4]; - - const auto packed_byte0 = (byte_lo & 0x0F) | (byte_hi << 4); - const auto packed_byte1 = (byte_lo >> 4) | (byte_hi & 0xF0); - - dst_ptr[0] = packed_byte0; // ^ 0x88; - dst_ptr[1] = packed_byte1; // ^ 0x88; - dst_ptr += 2; - } - } - } + for (size_t x_element = 0; x_element < quant_width; 
++x_element) { + const auto x = x_quant + x_element / 2 + (x_element % 2 != 0 ? quant_width / 2 : 0); + write_array(dst_ptr, x_element, read_array(data, y * width + x)); } - } - if (scales_ptr != nullptr) { - memcpy(dst_ptr, scales_ptr + block_row * block_height, dst_block_row_scales_bytes); + dst_ptr += quant_width * size_in_bits / 8; } - dst_ptr += dst_block_row_scales_bytes; } - KAI_ASSERT(reinterpret_cast(dst_ptr) - reinterpret_cast(dst.data()) == dst_bytes); + KAI_ASSERT(dst_ptr == &*dst.end()); return dst; } -} // namespace +template std::vector pack_data_scales_interleave_block( + const void* data, const void* scales, size_t height, size_t width, size_t quant_width); -std::vector pack( - const DataFormat& dst_format, const void* src, const void* scales, const void* zero_points, - const DataFormat& src_format, size_t height, size_t width) { - const auto dst_dt = dst_format.data_type(); - const auto dst_qf = dst_format.pack_format(); - const auto src_dt = src_format.data_type(); - const auto src_qf = src_format.pack_format(); +template +std::vector pack_block_data_zero_points_scale_bias( + const void* data, const void* zero_points, const void* scales, const void* biases, size_t height, size_t width, + size_t quant_height, size_t quant_width, size_t block_height, size_t block_width, size_t interleave_x_blocks) { + if (quant_width == width) { + quant_width = round_up_multiple(quant_width, block_width); + } - const auto block_height = dst_format.actual_block_height(height); - const auto block_width = dst_format.actual_block_width(width); - const auto subblock_height = dst_format.actual_subblock_height(height); - const auto subblock_width = dst_format.actual_subblock_width(width); + KAI_ASSERT(quant_height == block_height); + KAI_ASSERT(quant_width % block_width == 0); - if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::QUANTIZE_PER_ROW) { - if (dst_dt == DataType::QAI8 && src_dt == DataType::FP32 && dst_format.scale_data_type() == DataType::FP32 && - dst_format.zero_point_data_type() == DataType::I32) { - return pack_quant_per_row(src, height, width, block_height, block_width); - } else if ( - dst_dt == DataType::QSI4 && src_dt == DataType::QSU4 && dst_format.scale_data_type() == DataType::FP32 && - dst_format.zero_point_data_type() == DataType::I32) { - return pack_per_row_qs4( - src, scales, zero_points, height, width, block_height, block_width, subblock_height, subblock_width); - } + if (interleave_x_blocks == 0) { + interleave_x_blocks = quant_width / block_width; } - if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::BIAS_PER_ROW) { - KAI_ASSUME(src_dt == dst_dt); + const auto has_zero_points = zero_points != nullptr; + const auto has_biases = biases != nullptr; - const auto data_esize = data_type_size_in_bits(dst_dt); - const auto zero_point_esize = data_type_size_in_bits(dst_format.zero_point_data_type()); + const auto num_quant_packets_y = round_up_division(height, quant_height); + const auto num_quant_packets_x = round_up_division(width, quant_width); - if (data_esize % 8 == 0 && zero_point_esize % 8 == 0) { - return pack_bias_per_row( - data_esize / 8, zero_point_esize / 8, src, zero_points, height, width, block_height, block_width, - subblock_height, subblock_width); - } - } + const auto quant_packet_data_bytes = quant_height * quant_width * size_in_bits / 8; + const auto quant_packet_zero_points_bytes = has_zero_points ? 
quant_height * sizeof(ZeroPoint) : 0; + const auto quant_packet_scales_bytes = quant_height * sizeof(Scale); + const auto quant_packet_bytes = + quant_packet_zero_points_bytes + quant_packet_data_bytes + quant_packet_scales_bytes; - if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::NONE) { - KAI_ASSUME(src_dt == dst_dt); + const auto num_quant_packets_per_row = round_up_division(width, quant_width); + const auto biases_bytes = has_biases ? height * sizeof(Bias) : 0; - const auto data_esize = data_type_size_in_bits(dst_dt); + const auto dst_bytes = num_quant_packets_y * num_quant_packets_x * quant_packet_bytes + biases_bytes; + std::vector dst(dst_bytes); - if (data_esize % 8 == 0) { - return pack_block( - src, data_esize / 8, height, width, block_height, block_width, subblock_height, subblock_width); + const auto* zero_points_ptr = reinterpret_cast(zero_points); + const auto* scales_ptr = reinterpret_cast(scales); + const auto* biases_ptr = reinterpret_cast(biases); + auto* dst_ptr = dst.data(); + + for (size_t y_quant = 0; y_quant < height; y_quant += quant_height) { + for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { + size_t dst_index = 0; + + // Packs the data. + for (size_t y_pack = 0; y_pack < quant_height; y_pack += block_height) { + for (size_t x_pack = 0; x_pack < block_width * interleave_x_blocks; x_pack += block_width) { + for (size_t y_element = 0; y_element < block_height; ++y_element) { + for (size_t x_element = 0; x_element < block_width; ++x_element) { + for (size_t x_interleave = 0; x_interleave < quant_width; + x_interleave += block_width * interleave_x_blocks) { + const auto y = y_quant + y_pack + y_element; + const auto x = x_quant + x_pack + x_element + x_interleave; + + if (y < height && x < width) { + write_array(dst_ptr, dst_index, read_array(data, y * width + x)); + } + + ++dst_index; + } + } + } + } + } + + dst_ptr += dst_index * size_in_bits / 8; + + // Packs the zero points. + if (has_zero_points) { + for (size_t y_element = 0; y_element < quant_height; ++y_element) { + const auto y = y_quant + y_element; + const auto x = x_quant / quant_width; + memcpy(dst_ptr, &zero_points_ptr[y * num_quant_packets_per_row + x], sizeof(ZeroPoint)); + dst_ptr += sizeof(ZeroPoint); + } + } + + // Packs the scales. + for (size_t y_element = 0; y_element < quant_height; ++y_element) { + const auto y = y_quant + y_element; + const auto x = x_quant / quant_width; + memcpy(dst_ptr, &scales_ptr[y * num_quant_packets_per_row + x], sizeof(Scale)); + dst_ptr += sizeof(Scale); + } + } + + // Packs the biases. + if (has_biases) { + for (size_t y_element = 0; y_element < quant_height; ++y_element) { + const auto y = y_quant + y_element; + memcpy(dst_ptr, &biases_ptr[y], sizeof(Bias)); + dst_ptr += sizeof(Bias); + } } } - KAI_ERROR("Unsupported operation!"); + KAI_ASSERT(dst_ptr == &*dst.end()); + + return dst; } } // namespace kai::test diff --git a/test/reference/pack.hpp b/test/reference/pack.hpp index 43362009..60f19591 100644 --- a/test/reference/pack.hpp +++ b/test/reference/pack.hpp @@ -8,6 +8,7 @@ #include #include +#include #include namespace kai::test { @@ -25,4 +26,115 @@ std::vector pack( const DataFormat& dst_format, const void* src, const void* scales, const void* zero_points, const DataFormat& src_format, size_t height, size_t width); +/// Packs the quantized data and the quantization scale into a single buffer. 
+/// +/// ``` +/// Quantized data matrix: +/// +/// --->|-----------------|<--- Quantization block width +/// | | +/// +-----------------+-----------------+----- ... +/// | q00 q01 q02 q03 | q04 q05 q06 q07 | ........ +/// | q10 q11 q12 q13 | q14 q15 q16 q17 | ........ +/// | q20 q21 q22 q23 | q24 q25 q26 q27 | ........ +/// | q30 q31 q32 q33 | q34 q35 q36 q37 | ........ +/// | ............... | ............... | ........ +/// : ............... : ............... : ........ +/// +/// Quantization scale matrix: +/// +/// +-----+-----+-- ... +/// | s00 | s01 | ..... +/// | s10 | s11 | ..... +/// | s20 | s21 | ..... +/// | s30 | s31 | ..... +/// | ... | ... | ..... +/// : ... : ... : ..... +/// ``` +/// +/// The packed data has each quantized block row followed by the corresponding quantization scale. +/// +/// ``` +/// Packed data: +/// +/// +-----------------+-----+-----------------+-----+----- ... +/// | q00 q01 q02 q03 | s00 | q04 q05 q06 q07 | s01 | ........ +/// | q10 q11 q12 q13 | s10 | q14 q15 q16 q17 | s11 | ........ +/// | q20 q21 q22 q23 | s20 | q24 q25 q26 q27 | s21 | ........ +/// | q30 q31 q32 q33 | s30 | q34 q35 q36 q37 | s31 | ........ +/// | ............... | ... | ............... | ... | ........ +/// : ............... : ... : ............... : ... : ........ +/// ``` +/// +/// @tparam Data The data type of the quantized value. +/// @tparam Scale The data type of the quantization scale. +/// +/// @param[in] data The quantized data. +/// @param[in] scales The quantization scales. +/// @param[in] height The number of rows. +/// @param[in] width The number of columns. +/// @param[in] quant_width The number of columns of the quantization block. +/// +/// @return The packed data buffer. +template +std::vector pack_data_scales( + const void* data, const void* scales, size_t height, size_t width, size_t quant_width); + +/// Packs the quantized data and the quantization scale into a single buffer. +/// +/// ``` +/// Quantized data matrix: +/// +/// --->|-----------------|<--- Quantization block width +/// | | +/// +-----------------+-----------------+----- ... +/// | q00 q01 q02 q03 | q04 q05 q06 q07 | ........ +/// | q10 q11 q12 q13 | q14 q15 q16 q17 | ........ +/// | q20 q21 q22 q23 | q24 q25 q26 q27 | ........ +/// | q30 q31 q32 q33 | q34 q35 q36 q37 | ........ +/// | ............... | ............... | ........ +/// : ............... : ............... : ........ +/// +/// Quantization scale matrix: +/// +/// +-----+-----+-- ... +/// | s00 | s01 | ..... +/// | s10 | s11 | ..... +/// | s20 | s21 | ..... +/// | s30 | s31 | ..... +/// | ... | ... | ..... +/// : ... : ... : ..... +/// ``` +/// +/// The packed data has each quantized block row followed by the corresponding quantization scale. +/// +/// This function is different from @ref pack_data_scales that in this packing method +/// the quantized data row is splitted into two halves and they are interleaved together. +/// +/// ``` +/// Packed data: +/// +/// +-----------------+-----+-----------------+-----+----- ... +/// | q00 q02 q01 q03 | s00 | q04 q06 q05 q07 | s01 | ........ +/// | q10 q12 q11 q13 | s10 | q14 q16 q15 q17 | s11 | ........ +/// | q20 q22 q21 q23 | s20 | q24 q26 q25 q27 | s21 | ........ +/// | q30 q32 q31 q33 | s30 | q34 q36 q35 q37 | s31 | ........ +/// | ............... | ... | ............... | ... | ........ +/// : ............... : ... : ............... : ... : ........ +/// ``` +/// +/// @tparam Data The data type of the quantized value. +/// @tparam Scale The data type of the quantization scale. 
+/// +/// @param[in] data The quantized data. +/// @param[in] scales The quantization scales. +/// @param[in] height The number of rows. +/// @param[in] width The number of columns. +/// @param[in] quant_width The number of columns of the quantization block. +/// +/// @return The packed data buffer. +template +std::vector pack_data_scales_interleave_block( + const void* data, const void* scales, size_t height, size_t width, size_t quant_width); + } // namespace kai::test diff --git a/test/reference/quantize.cpp b/test/reference/quantize.cpp index 144e7a4f..ad4a450f 100644 --- a/test/reference/quantize.cpp +++ b/test/reference/quantize.cpp @@ -13,8 +13,6 @@ #include #include -#include "kai/kai_common.h" -#include "test/common/data_type.hpp" #include "test/common/int4.hpp" #include "test/common/memory.hpp" #include "test/common/numeric_limits.hpp" @@ -23,9 +21,12 @@ namespace kai::test { -std::tuple get_qai8_scale_zero_point_from_range(float min_value, float max_value) { - constexpr float q_min = std::numeric_limits::min(); - constexpr float q_max = std::numeric_limits::max(); +namespace { + +template +std::tuple get_scale_zero_point_from_range(FloatData min_value, FloatData max_value) { + constexpr FloatData q_min = std::numeric_limits::min(); + constexpr FloatData q_max = std::numeric_limits::max(); if (min_value > 0) { min_value = 0; @@ -37,82 +38,159 @@ std::tuple get_qai8_scale_zero_point_from_range(float min_value, // The reason for computing the inverted scale first is to make it bit-perfect with quantized packing kernels. // If those kernels don't do it this way anymore, it makes more sense to calculate the scale directly. - const float inv_scale = max_value != min_value ? (q_max - q_min) / (max_value - min_value) : 1.0F; - const float scale = 1.0F / inv_scale; + const FloatData inv_scale = max_value != min_value ? (q_max - q_min) / (max_value - min_value) : 1.0F; + const FloatData scale = 1.0F / inv_scale; - const float scaled_min = min_value / scale; - const float scaled_max = max_value / scale; + const FloatData scaled_min = min_value / scale; + const FloatData scaled_max = max_value / scale; - const float zero_point_f = -(scaled_min + q_min) < scaled_max + q_max ? scaled_min - q_min : scaled_max - q_max; - const int32_t zero_point = round_to_nearest_even_i32(zero_point_f); + const FloatData zero_point_f = -(scaled_min + q_min) < scaled_max + q_max ? scaled_min - q_min : scaled_max - q_max; + const ZeroPoint zero_point = round_to_nearest_even(zero_point_f); return {scale, zero_point}; } -int8_t quantize_i8_fp32(float value, float scale, int32_t zero_point) { - return static_cast(std::clamp( - round_to_nearest_even_i32(value / scale) - zero_point, std::numeric_limits::min(), - std::numeric_limits::max())); +template +IntType quantize_symmetric(float value, float scale) { + const auto inv_scale = scale != 0 ? 1.0F / scale : 0.0F; + auto qsi32 = round_to_nearest_even_i32(value * inv_scale); + + if (is_unsigned) { + qsi32 += 1 << (size_in_bits - 1); + } + + return static_cast(std::clamp(qsi32, numeric_lowest, numeric_highest)); } -namespace { +template +IntType quantize_asymmetric(FloatType value, FloatType scale, ZeroPointType zero_point) { + const auto inv_scale = scale != 0 ? 
1.0F / scale : 0.0F; + auto quantized_value = round_to_nearest_even(value * inv_scale) - zero_point; + return static_cast( + std::clamp(quantized_value, numeric_lowest, numeric_highest)); +} + +} // namespace -template -std::vector dequantize_any_type( - const void* data, const void* scales, const void* zero_points, // - QuantizationMethod method, bool is_asymm, size_t height, size_t width) { - static_assert(is_floating_point); - static_assert(is_integral); +template +std::tuple, std::vector> quantize_symmetric_per_block( + const void* src, size_t height, size_t width, size_t quant_width) { + static_assert(is_floating_point); + static_assert(is_integral); + static_assert(is_floating_point); - std::vector dst; - dst.resize(height * width * sizeof(Output)); - KAI_ASSUME(size_in_bits % 8 == 0); + const auto num_quant_packets_x = round_up_division(width, quant_width); - auto scale = read_array(scales, 0); - KAI_UNUSED(is_asymm); - KAI_UNUSED(zero_points); - auto zero_point = is_asymm ? read_array(zero_points, 0) : // - -static_cast(numeric_lowest>); + const auto data_bytes = round_up_division(height * width * size_in_bits, 8); + std::vector data(data_bytes); + + const auto scales_bytes = height * num_quant_packets_x * sizeof(ScaleType); + std::vector scales(scales_bytes); + + const auto* src_ptr = reinterpret_cast(src); for (size_t y = 0; y < height; ++y) { - if (method == QuantizationMethod::PER_ROW) { - scale = read_array(scales, y); + for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { + // Computes the quantization scale. + SrcType max_abs = 0; + + for (size_t x_element = 0; x_element < quant_width; ++x_element) { + const auto x = x_quant + x_element; - if (is_asymm) { - zero_point = read_array(zero_points, y); + if (x < width) { + max_abs = std::max(max_abs, std::abs(src_ptr[y * width + x])); + } } - } - for (size_t x = 0; x < width; ++x) { - const ZeroPoint input = read_array(data, y * width + x); - const Scale output = static_cast(input - zero_point) * scale; - write_array(dst.data(), y * width + x, output); + const auto scale = max_abs / ((1 << (size_in_bits - 1)) - 1); + + // Stores the scales. + write_array(scales.data(), y * num_quant_packets_x + x_quant / quant_width, scale); + + // Quantizes and stores the data. 
+ for (size_t x_element = 0; x_element < quant_width; ++x_element) { + const auto x = x_quant + x_element; + + if (x < width) { + const auto quantized = quantize_symmetric(src_ptr[y * width + x], scale); + write_array(data.data(), y * width + x, quantized); + } + } } } - return dst; + return {data, scales}; } -} // namespace +template std::tuple, std::vector> quantize_symmetric_per_block( + const void* src, size_t height, size_t width, size_t quant_width); +template std::tuple, std::vector> quantize_symmetric_per_block( + const void* src, size_t height, size_t width, size_t quant_width); +template std::tuple, std::vector> quantize_symmetric_per_block( + const void* src, size_t height, size_t width, size_t quant_width); -std::vector dequantize( - const void* data, const void* scales, const void* zero_points, // - DataType src_dt, DataType dst_dt, QuantizationMethod method, // - size_t height, size_t width) { - switch (src_dt) { - case DataType::QSU4: - switch (dst_dt) { - case DataType::FP32: - return dequantize_any_type( - data, scales, zero_points, method, false, height, width); - - default: - KAI_ERROR("Unsupported destination data type!"); +template +std::tuple, std::vector, std::vector> quantize_asymmetric_per_block( + const void* src, size_t height, size_t width, size_t quant_width) { + static_assert(is_floating_point); + static_assert(is_integral); + static_assert(is_floating_point); + static_assert(is_integral); + + const auto num_quant_packets_x = round_up_division(width, quant_width); + + const auto data_bytes = round_up_division(height * width * size_in_bits, 8); + std::vector data(data_bytes); + + const auto scales_bytes = height * num_quant_packets_x * sizeof(ScaleType); + std::vector scales(scales_bytes); + + const auto zero_points_bytes = height * num_quant_packets_x * sizeof(ZeroPointType); + std::vector zero_points(zero_points_bytes); + + for (size_t y = 0; y < height; ++y) { + for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { + // Computes the quantization scale and zero point. + auto min_value = std::numeric_limits::max(); + auto max_value = std::numeric_limits::lowest(); + + for (size_t x_element = 0; x_element < quant_width; ++x_element) { + const auto x = x_quant + x_element; + + if (x < width) { + const auto value = read_array(src, y * width + x); + + min_value = std::min(min_value, value); + max_value = std::max(max_value, value); + } } - default: - KAI_ERROR("Unsupported source data type!"); + const auto [scale, zero_point] = + get_scale_zero_point_from_range(min_value, max_value); + + // Stores the scale and zero point. + write_array(scales.data(), y * num_quant_packets_x + x_quant / quant_width, scale); + write_array(zero_points.data(), y * num_quant_packets_x + x_quant / quant_width, zero_point); + + // Quantizes and stores the data. 
+ for (size_t x_element = 0; x_element < quant_width; ++x_element) { + const auto x = x_quant + x_element; + + if (x < width) { + const auto value_f = read_array(src, y * width + x); + const auto value_q = + quantize_asymmetric(value_f, scale, zero_point); + + write_array(data.data(), y * width + x, value_q); + } + } + } } + + return {data, scales, zero_points}; } +template std::tuple, std::vector, std::vector> quantize_asymmetric_per_block< + float, int8_t, float, int32_t>(const void* src, size_t height, size_t width, size_t quant_width); + } // namespace kai::test diff --git a/test/reference/quantize.hpp b/test/reference/quantize.hpp index b9865903..58eb88bb 100644 --- a/test/reference/quantize.hpp +++ b/test/reference/quantize.hpp @@ -11,8 +11,6 @@ #include #include -#include "test/common/data_type.hpp" - namespace kai::test { /// Quantization method. @@ -21,40 +19,160 @@ enum class QuantizationMethod : uint32_t { PER_ROW, ///< Per-row, i.e. one quantization scale and zero point for each row. }; -/// Calculates the quantization information for 8-bit signed asymmetric type from the value range. +/// Quantizes each subblock of the matrix using symmetric quantization method. /// -/// @param[in] min_value Minimum value. -/// @param[in] max_value Maximum value. +/// The input matrix is divided into quantization blocks of the same size. /// -/// @return The scale and zero point. -std::tuple get_qai8_scale_zero_point_from_range(float min_value, float max_value); - -/// Quantizes the single-precision floating-point value using 8-bit asymmetric quantization. +/// The height of the block does not effect the behavior of this function hence it is omitted +/// from the function arguments and the figures below. +/// +/// ``` +/// Quantization blocks -------+ +/// | | +/// | | +/// v v +/// +-----------------+-----------------+----- ... +/// | f00 f01 f02 f03 | f04 f05 f06 f07 | ........ +/// | f10 f11 f12 f13 | f14 f15 f16 f17 | ........ +/// | f20 f21 f22 f23 | f24 f25 f26 f27 | ........ +/// | f30 f31 f32 f33 | f34 f35 f36 f37 | ........ +/// | ............... | ............... | ........ +/// : ............... : ............... : ........ +/// ``` +/// +/// Each row of the quantization block is quantized individually. +/// +/// ``` +/// Floating-point data Scale Quantized data +/// +-----------------+ +-----+ +-----------------+ +/// | f00 f01 f02 f03 | -------> | s00 | | q00 q01 q02 q03 | +/// | f10 f11 f12 f13 | -------> | s10 | | q10 q11 q12 q13 | +/// | f20 f21 f22 f23 | -------> | s20 | | q20 q21 q22 q23 | +/// | f30 f31 f32 f33 | -------> | s30 | | q30 q31 q32 q33 | +/// | ............... | | ... | | ............... | +/// : ............... : : ... : : ............... : +/// ``` +/// +/// The quantization scale and quantized data are stored in separate buffers. +/// +/// ``` +/// Quantized data matrix: /// -/// Formula: `q = f / scale + zero_point` where `q` is quantized value and `f` is floating-point value. +/// +-----------------+-----------------+----- ... +/// | q00 q01 q02 q03 | q04 q05 q06 q07 | ........ +/// | q10 q11 q12 q13 | q14 q15 q16 q17 | ........ +/// | q20 q21 q22 q23 | q24 q25 q26 q27 | ........ +/// | q30 q31 q32 q33 | q34 q35 q36 q37 | ........ +/// | ............... | ............... | ........ +/// : ............... : ............... : ........ /// -/// @param[in] value Value to be quantized. -/// @param[in] scale Scale. -/// @param[in] zero_point Zero point. +/// Quantization scale matrix: /// -/// @return The quantized value. 
-int8_t quantize_i8_fp32(float value, float scale, int32_t zero_point); +/// +-----+-----+-- ... +/// | s00 | s01 | ..... +/// | s10 | s11 | ..... +/// | s20 | s21 | ..... +/// | s30 | s31 | ..... +/// | ... | ... | ..... +/// : ... : ... : ..... +/// ``` +/// +/// @tparam SrcType The data type of the input data (must be floating-point). +/// @tparam DstType The data type of the output data (must be integer). +/// @tparam ScaleType The data type of the quantization scales (must be floating-point). +/// +/// @param[in] src The input matrix. +/// @param[in] height The number of rows. +/// @param[in] width The number of columns. +/// @param[in] quant_width The number of columns of the quantization block. +/// +/// @return The quantized data matrix and the quantization scale matrix. +template +std::tuple, std::vector> quantize_symmetric_per_block( + const void* src, size_t height, size_t width, size_t quant_width); -/// Dequantizes the matrix to floating-point. -/// -/// @param[in] data Quantized data buffer. -/// @param[in] scales Quantization scales. -/// @param[in] zero_points (Optional) Quantization zero points. -/// @param[in] src_dt Quantized data type. -/// @param[in] dst_dt Dequantized data type. -/// @param[in] method Quantization method. -/// @param[in] height Number of rows. -/// @param[in] width Number of columns. -/// -/// @return The dequantized data buffer. -std::vector dequantize( - const void* data, const void* scales, const void* zero_points, // - DataType src_dt, DataType dst_dt, QuantizationMethod method, // - size_t height, size_t width); +/// Quantizes each subblock of the matrix using asymmetric quantization method. +/// +/// The input matrix is divided into quantization blocks of the same size. +/// +/// The height of the block does not effect the behavior of this function hence it is omitted +/// from the function arguments and the figures below. +/// +/// ``` +/// Quantization blocks -------+ +/// | | +/// | | +/// v v +/// +-----------------+-----------------+----- ... +/// | f00 f01 f02 f03 | f04 f05 f06 f07 | ........ +/// | f10 f11 f12 f13 | f14 f15 f16 f17 | ........ +/// | f20 f21 f22 f23 | f24 f25 f26 f27 | ........ +/// | f30 f31 f32 f33 | f34 f35 f36 f37 | ........ +/// | ............... | ............... | ........ +/// : ............... : ............... : ........ +/// ``` +/// +/// Each row of the quantization block is quantized individually. +/// +/// ``` +/// Floating-point data Scale Zero point Quantized data +/// +-----------------+ +-----+ +-----+ +-----------------+ +/// | f00 f01 f02 f03 | -------> | s00 | | z00 | | q00 q01 q02 q03 | +/// | f10 f11 f12 f13 | -------> | s10 | | z10 | | q10 q11 q12 q13 | +/// | f20 f21 f22 f23 | -------> | s20 | | z20 | | q20 q21 q22 q23 | +/// | f30 f31 f32 f33 | -------> | s30 | | z30 | | q30 q31 q32 q33 | +/// | ............... | | ... | | ... | | ............... | +/// : ............... : : ... : : ... : : ............... : +/// ``` +/// +/// The quantization scales, zero points quantized data are stored in separate buffers. +/// +/// ``` +/// Quantized data matrix: +/// +/// +-----------------+-----------------+----- ... +/// | q00 q01 q02 q03 | q04 q05 q06 q07 | ........ +/// | q10 q11 q12 q13 | q14 q15 q16 q17 | ........ +/// | q20 q21 q22 q23 | q24 q25 q26 q27 | ........ +/// | q30 q31 q32 q33 | q34 q35 q36 q37 | ........ +/// | ............... | ............... | ........ +/// : ............... : ............... : ........ +/// +/// Quantization scale matrix: +/// +/// +-----+-----+-- ... 
+/// | s00 | s01 | ..... +/// | s10 | s11 | ..... +/// | s20 | s21 | ..... +/// | s30 | s31 | ..... +/// | ... | ... | ..... +/// : ... : ... : ..... +/// ``` +/// +/// Quantization zero point matrix: +/// +/// +-----+-----+-- ... +/// | z00 | z01 | ..... +/// | z10 | z11 | ..... +/// | z20 | z21 | ..... +/// | z30 | z31 | ..... +/// | ... | ... | ..... +/// : ... : ... : ..... +/// ``` +/// +/// @tparam SrcType The data type of the input data (must be floating-point). +/// @tparam DstType The data type of the output data (must be integer). +/// @tparam ScaleType The data type of the quantization scales (must be floating-point). +/// @tparam ZeroPointType The data type of the quantization zero points (must be integer). +/// +/// @param[in] src The input matrix. +/// @param[in] height The number of rows. +/// @param[in] width The number of columns. +/// @param[in] quant_width The number of columns of the quantization block. +/// +/// @return The quantized data matrix, the scale matrix and the zero point matrix. +template +std::tuple, std::vector, std::vector> quantize_asymmetric_per_block( + const void* src, size_t height, size_t width, size_t quant_width); } // namespace kai::test diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp new file mode 100644 index 00000000..8407c15f --- /dev/null +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp @@ -0,0 +1,96 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include + +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h" +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h" +#include "test/common/int4.hpp" +#include "test/common/memory.hpp" +#include "test/reference/cast.hpp" +#include "test/reference/fill.hpp" +#include "test/reference/matmul.hpp" +#include "test/reference/quantize.hpp" + +namespace kai::test { + +TEST(matmul_clamp_f32_qai8dxp_qai4cxp, EndToEnd) { + const uint64_t seed = 0; + + const size_t M = 16; + const size_t N = 32; + const size_t K = 64; + + const auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(); + const auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(); + const auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(); + const auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(); + + // Generates input data. + const auto ref_lhs = fill_random(M * K, seed + 0); + const auto ref_rhs = fill_random(N * K, seed + 1); + const auto ref_biases = fill_random(N, seed + 2); + + // Runs the reference implementation. + // * Quantizes the LHS matrix using 8-bit asymmetric quantization. + // * Quantizes the RHS matrix using 4-bit symmetric quantization. + // * Performs GEMM. 
+ const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = + quantize_asymmetric_per_block(ref_lhs.data(), M, K, K); + const auto [ref_rhs_qsi4, ref_rhs_scales] = + quantize_symmetric_per_block(ref_rhs.data(), N, K, K); + + const auto ref_dst = matmul_clamp_nt_t( + M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), + ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits::lowest(), + std::numeric_limits::max()); + + // Runs the LHS packing micro-kernel. + const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); + std::vector imp_packed_lhs(imp_packed_lhs_size); + kai_run_lhs_quant_pack_qai8dxp_f32( + M, K, mr, kr, sr, 0, reinterpret_cast(ref_lhs.data()), K * sizeof(float), imp_packed_lhs.data()); + + // Runs the RHS packing micro-kernel. + // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel. + // * Packs the RHS matrix. + const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K); + const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(N, K, nr, kr, sr); + std::vector imp_packed_rhs(imp_packed_rhs_size); + const kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params params{.lhs_zero_point = 1, .rhs_zero_point = 8}; + kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + 1, N, K, nr, kr, sr, ref_rhs_qsu4.data(), reinterpret_cast(ref_biases.data()), + reinterpret_cast(ref_rhs_scales.data()), imp_packed_rhs.data(), 0, ¶ms); + + // Runs the GEMM micro-kernel. + const auto imp_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(M, N); + ASSERT_EQ(imp_dst_size, ref_dst.size()); + std::vector imp_dst(imp_dst_size); + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm( + M, N, K, imp_packed_lhs.data(), imp_packed_rhs.data(), reinterpret_cast(imp_dst.data()), + N * sizeof(float), sizeof(float), std::numeric_limits::lowest(), std::numeric_limits::max()); + + // Compares the output of the micro-kernels against the output of the reference implementation. + for (size_t y = 0; y < M; ++y) { + for (size_t x = 0; x < N; ++x) { + const auto imp_value = read_array(imp_dst.data(), y * N + x); + const auto ref_value = read_array(ref_dst.data(), y * N + x); + const auto rel_error = ref_value != 0 ? 
std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value); + + if (rel_error > 0.0001F) { + ASSERT_EQ(imp_value, ref_value); + } + } + } +} + +} // namespace kai::test -- GitLab From aa6a13f57675e00d70b39ce7ff2185e3903d3178 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 15 Aug 2024 11:54:36 +0100 Subject: [PATCH 2/3] Fix compilation error and add more documentation Signed-off-by: Viet-Hoa Do --- test/reference/matmul.cpp | 2 +- test/reference/matmul.hpp | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp index ea764ca5..72d03bb2 100644 --- a/test/reference/matmul.cpp +++ b/test/reference/matmul.cpp @@ -6,6 +6,7 @@ #include "test/reference/matmul.hpp" +#include #include #include #include @@ -20,7 +21,6 @@ #include "test/reference/binary_elementwise.hpp" #include "test/reference/cast.hpp" #include "test/reference/pack.hpp" -#include "test/reference/quantize.hpp" #include "test/reference/reduce.hpp" #include "test/reference/transpose.hpp" diff --git a/test/reference/matmul.hpp b/test/reference/matmul.hpp index 68b5de90..40fb684d 100644 --- a/test/reference/matmul.hpp +++ b/test/reference/matmul.hpp @@ -65,6 +65,36 @@ std::vector matmul( size_t m, size_t n, size_t k, // bool lhs_transposed, bool rhs_transposed); +/// Matrix multiplication with quantized input and floating-point output. +/// +/// The LHS matrix is non-transposed and the RHS matrix is transposed. +/// +/// @tparam LhsData The data type of the LHS matrix. +/// @tparam LhsScale The data type of the quantization scales of the LHS matrix. +/// @tparam LhsZeroPoint The data type of the quantization zero points of the LHS matrix. +/// @tparam Rhsdata The data type of the RHS matrix. +/// @tparam RhsScale The data type of the quantization scales of the RHS matrix. +/// @tparam RhsZeroPoint The data type of the quantization zero points of the RHS matrix. +/// @tparam Bias The data type of the bias vector. +/// @tparam IntAcc The data type of the intermediate integer accumulator. +/// @tparam DstData The data type of the floating-point accumulator and the output matrix. +/// +/// @param[in] m The LHS and output height. +/// @param[in] n The RHS height and output width. +/// @param[in] k The LHS and RHS width. +/// @param[in] lhs_data The LHS data matrix. +/// @param[in] lhs_scales The LHS quantization scales matrix. +/// @param[in] lhs_zero_points The LHS quantization zero points matrix. +/// @param[in] lhs_quant_width The LHS quantization block width. +/// @param[in] rhs_data The RHS data matrix. +/// @param[in] rhs_scales The RHS quantization scales matrix. +/// @param[in] rhs_zero_points The RHS quantization zero points matrix. +/// @param[in] rhs_quant_width The RHS quantization block width. +/// @param[in] biases The biases vector. +/// @param[in] min_value The minimum output value. +/// @param[in] max_value The maximum output value. +/// +/// @return The output matrix. 
template < typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData> -- GitLab From 8ae5521f3b78bd13f2daa00104952f255ff4afe5 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 15 Aug 2024 14:51:48 +0100 Subject: [PATCH 3/3] Update bazel build and fix documentation Signed-off-by: Viet-Hoa Do --- test/BUILD.bazel | 1 + test/reference/pack.hpp | 33 ++++++++++++++++----------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/test/BUILD.bazel b/test/BUILD.bazel index 85d14acd..2ad11215 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -53,6 +53,7 @@ kai_cxx_library( cc_test( name = "kleidiai_test", srcs = [ + "tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp", "tests/matmul_test.cpp", ], copts = kai_cxxopts(kai_cpu_bf16() + kai_cpu_fp16()), diff --git a/test/reference/pack.hpp b/test/reference/pack.hpp index 60f19591..8564c810 100644 --- a/test/reference/pack.hpp +++ b/test/reference/pack.hpp @@ -8,7 +8,6 @@ #include #include -#include #include namespace kai::test { @@ -52,18 +51,18 @@ std::vector pack( /// : ... : ... : ..... /// ``` /// -/// The packed data has each quantized block row followed by the corresponding quantization scale. +/// The packed data has each quantization scale followed by the quantized block row. /// /// ``` /// Packed data: /// -/// +-----------------+-----+-----------------+-----+----- ... -/// | q00 q01 q02 q03 | s00 | q04 q05 q06 q07 | s01 | ........ -/// | q10 q11 q12 q13 | s10 | q14 q15 q16 q17 | s11 | ........ -/// | q20 q21 q22 q23 | s20 | q24 q25 q26 q27 | s21 | ........ -/// | q30 q31 q32 q33 | s30 | q34 q35 q36 q37 | s31 | ........ -/// | ............... | ... | ............... | ... | ........ -/// : ............... : ... : ............... : ... : ........ +/// +-----+-----------------+-----+-----------------+----- ... +/// | s00 | q00 q01 q02 q03 | s01 | q04 q05 q06 q07 | ........ +/// | s10 | q10 q11 q12 q13 | s11 | q14 q15 q16 q17 | ........ +/// | s20 | q20 q21 q22 q23 | s21 | q24 q25 q26 q27 | ........ +/// | s30 | q30 q31 q32 q33 | s31 | q34 q35 q36 q37 | ........ +/// | ... | ............... | ... | ............... | ........ +/// : ... : ............... : ... : ............... : ........ /// ``` /// /// @tparam Data The data type of the quantized value. @@ -106,7 +105,7 @@ std::vector pack_data_scales( /// : ... : ... : ..... /// ``` /// -/// The packed data has each quantized block row followed by the corresponding quantization scale. +/// The packed data has each quantization scale followed by the quantized block row. /// /// This function is different from @ref pack_data_scales that in this packing method /// the quantized data row is splitted into two halves and they are interleaved together. @@ -114,13 +113,13 @@ std::vector pack_data_scales( /// ``` /// Packed data: /// -/// +-----------------+-----+-----------------+-----+----- ... -/// | q00 q02 q01 q03 | s00 | q04 q06 q05 q07 | s01 | ........ -/// | q10 q12 q11 q13 | s10 | q14 q16 q15 q17 | s11 | ........ -/// | q20 q22 q21 q23 | s20 | q24 q26 q25 q27 | s21 | ........ -/// | q30 q32 q31 q33 | s30 | q34 q36 q35 q37 | s31 | ........ -/// | ............... | ... | ............... | ... | ........ -/// : ............... : ... : ............... : ... : ........ +/// +-----+-----------------+-----+-----------------+----- ... +/// | s00 | q00 q02 q01 q03 | s01 | q04 q06 q05 q07 | ........ +/// | s10 | q10 q12 q11 q13 | s11 | q14 q16 q15 q17 | ........ 
+/// | s20 | q20 q22 q21 q23 | s21 | q24 q26 q25 q27 | ........ +/// | s30 | q30 q32 q31 q33 | s31 | q34 q36 q35 q37 | ........ +/// | ... | ............... | ... | ............... | ........ +/// : ... : ............... : ... : ............... : ........ /// ``` /// /// @tparam Data The data type of the quantized value. -- GitLab
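The new test relies on two conventions that are easy to miss when reading the packed buffers: the reference RHS is quantized with 4-bit signed symmetric quantization (per-block scale = max_abs / 7, values clamped to [-8, 7]), and cast_qsu4_qsi4 together with the rhs_zero_point = 8 packing parameter shifts those values into the unsigned [0, 15] range expected by kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0. A minimal standalone sketch of that mapping, using plain int storage and illustrative values rather than the nibble-packed buffers of the test helpers:

```cpp
#include <algorithm>  // std::clamp, std::max
#include <cmath>      // std::abs, std::nearbyint
#include <cstdio>
#include <vector>

int main() {
    // Illustrative floating-point row (one quantization block).
    const std::vector<float> row = {-1.5F, -0.25F, 0.0F, 0.75F, 1.5F};

    // Per-block symmetric scale: the largest magnitude maps to +/-7.
    float max_abs = 0.0F;
    for (const float v : row) {
        max_abs = std::max(max_abs, std::abs(v));
    }
    const float scale = max_abs / 7.0F;

    for (const float v : row) {
        // 4-bit signed symmetric value in [-8, 7].
        const int qsi4 = std::clamp(static_cast<int>(std::nearbyint(v / scale)), -8, 7);

        // Unsigned 4-bit representation in [0, 15]: the +8 offset is the shift
        // applied by cast_qsu4_qsi4 and undone by rhs_zero_point = 8.
        const int qsu4 = qsi4 + 8;

        std::printf("%+.2f -> qsi4 %+d -> qsu4 %d\n", static_cast<double>(v), qsi4, qsu4);
    }

    return 0;
}
```

The reference helpers round to nearest with ties to even; std::nearbyint behaves the same way under the default floating-point rounding mode.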