diff --git a/CMakeLists.txt b/CMakeLists.txt index 642fce8d016c1163d42bc1b427e98102f37d7bae..389b0a4ea1ad743bc55cbf5beb4cef7b83adeb24 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -118,6 +118,7 @@ set(KLEIDIAI_FILES_SME kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32pb_f32_f32_16vlx1_sme.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme.c ) set(KLEIDIAI_FILES_SME2 @@ -180,12 +181,14 @@ if(KLEIDIAI_BUILD_TESTS) test/reference/binary_elementwise.cpp test/reference/matmul.cpp + test/reference/matmul_pack.cpp test/reference/fill.cpp test/reference/pack.cpp test/reference/quantize.cpp test/reference/reduce.cpp test/reference/transpose.cpp test/reference/cast.cpp + test/reference/reorder.cpp ) target_compile_options(kleidiai_test_framework @@ -205,6 +208,7 @@ if(KLEIDIAI_BUILD_TESTS) test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp + test/tests/matmul_clamp_qai8_qai8p_qsi8cp_test.cpp ) target_link_libraries(kleidiai_test diff --git a/kai/kai_common.h b/kai/kai_common.h index 2703418542324f4c7a69635a88aae45a86d9d2df..6e46901e770f2d264db872f59b8061b29bc04673 100644 --- a/kai/kai_common.h +++ b/kai/kai_common.h @@ -161,6 +161,11 @@ inline static int8_t kai_ext_sign_i8_i4(int8_t value) { return (value ^ 0x8) - 8; } +struct kai_rhs_pack_qsi8_params { + int32_t input_zero_point; + float scale_multiplier; +}; + #ifdef __cplusplus } #endif diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme.c new file mode 100644 index 0000000000000000000000000000000000000000..ac18e9cd0597cbfda13e05889889dad69965a52a --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme.c @@ -0,0 +1,199 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. 
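+// Packed-layout note (this mirrors the stride computation in this file and the
+// reference packing in test/reference/matmul_pack.cpp): the packed RHS is organised
+// in blocks of nr = 2 * (vector length in bytes) / 4 output channels. Each block
+// stores one int32 bias per channel, then the int8 data for those channels with k
+// rounded up to kr = 4, then one float scale per channel. The stored bias is the
+// effective bias, bias[n] - input_zero_point * sum_k(rhs[k][n]), and the stored
+// scale is scale[n] * scale_multiplier.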
+ +#include "kai_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme.h" + +#include +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_nr = 2; +static const size_t kai_kr = 4; +static const size_t kai_num_bytes_input = 1; +static const size_t kai_num_bytes_output = 1; +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_num_bytes_scale = 4; + +size_t kai_get_n_step_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme(void) { + return kai_nr * kai_get_sme_vector_length_u8(); +} + +size_t kai_get_rhs_offset_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme(size_t n_idx) { + KAI_ASSUME(n_idx % (kai_nr * kai_get_sme_vector_length_u8()) == 0); + + return n_idx * kai_num_bytes_input; +} + +size_t kai_get_bias_offset_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme(size_t n_idx) { + return n_idx * kai_num_bytes_bias; +} + +size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme(size_t k) { + return kai_nr * kai_get_sme_vector_length_u8() / kai_kr * + (kai_num_bytes_bias + kai_roundup(k, kai_kr) * kai_num_bytes_output + kai_num_bytes_scale); +} + +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % (kai_nr * kai_get_sme_vector_length_u8() / kai_kr) == 0); + + return n_idx * (kai_num_bytes_bias + kai_roundup(k, kai_kr) * kai_num_bytes_output + kai_num_bytes_scale); +} + +size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme(size_t n, size_t k) { + return kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme( + kai_roundup(n, kai_nr * kai_get_sme_vector_length_u8() / kai_kr), k); +} + +void kai_run_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_qsi8_params* params) { + KAI_ASSUME(num_groups == 1); + KAI_ASSUME(nr == kai_nr * kai_get_sme_vector_length_u8() / kai_kr); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == 1); + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(bias != NULL); + KAI_ASSUME(rhs_packed != NULL); + KAI_ASSUME(extra_bytes == 0); + KAI_ASSUME(params != NULL); + + size_t height = k; + const size_t width = n; + const void* in = rhs; + void* out = rhs_packed; + const size_t in_stride = rhs_stride; + uint8_t* pad_row = (uint8_t*)alloca(width * sizeof(uint8_t)); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(uint8_t)); + } + + size_t out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme(height); + const int32_t input_zero_point = params->input_zero_point; + const float scale_multiplier = params->scale_multiplier; + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x27, %x[out]\n" + "mov x26, %x[height]\n" + "ptrue p2.b\n" + "incb %x[out], ALL, MUL #2\n" + "1:" // Main row loop: Head + "mov x25, %x[in]\n" + "cmp %x[height], #0x3\n" + "add x24, x25, %x[in_stride]\n" + "mov x23, %x[out]\n" + "add x22, x24, %x[in_stride]\n" + "mov x21, %x[width]\n" + "add x20, x22, %x[in_stride]\n" + "csel x22, x22, %x[pad_row], GE\n" + "add %x[in], x20, %x[in_stride]\n" + "csel x20, x20, %x[pad_row], GT\n" + "cmp %x[height], #0x1\n" + "sub %x[height], %x[height], #0x4\n" + "csel x24, x24, %x[pad_row], GT\n" + "2:" // Main row loop: Column loop + "whilelt p0.b, XZR, x21\n" + "decw x21, ALL, MUL #2\n" + "ld1b { z18.b }, p0/Z, [x25]\n" + "cmp x21, #0x0\n" + "incd x25, ALL, MUL #4\n" + "ld1b { z19.b }, 
p0/Z, [x24]\n" + "incd x24, ALL, MUL #4\n" + "ld1b { z17.b }, p0/Z, [x22]\n" + "incd x22, ALL, MUL #4\n" + "ld1b { z16.b }, p0/Z, [x20]\n" + "incd x20, ALL, MUL #4\n" + "zip1 z18.b, z18.b, z17.b\n" + "zip1 z16.b, z19.b, z16.b\n" + "zip1 z17.b, z18.b, z16.b\n" + "zip2 z16.b, z18.b, z16.b\n" + "st1b { z17.b }, p2, [x23]\n" + "st1b { z16.b }, p2, [x23, #1, MUL VL]\n" + "add x23, x23, %x[out_stride]\n" + "bgt 2b\n" + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #2\n" + "bge 1b\n" + "mov x22, %x[out]\n" + "mov x21, %x[width]\n" + "dup z18.s, %w[scale_multiplier]\n" + "cbz %x[scale], 5f\n" + "4:" // Scale: Full loop + "mov x20, x21\n" + "decw x21, ALL, MUL #2\n" + "whilelt p1.s, XZR, x20\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z17.s }, p1/Z, [%x[scale]]\n" + "cmp x21, #0x0\n" + "ld1w { z16.s }, p0/Z, [%x[scale], #1, MUL VL]\n" + "incb %x[scale], ALL, MUL #2\n" + "fmul z17.s, z17.s, z18.s\n" + "fmul z16.s, z16.s, z18.s\n" + "st1w { z17.s }, p2, [x22]\n" + "st1w { z16.s }, p2, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 4b\n" + "5:" // Scale: Done + "cbz %x[width], 8f\n" + "cbz x26, 8f\n" + "dup z21.s, %w[input_zero_point]\n" + "add x25, x26, #0x3\n" + "cntw x24, ALL, MUL #2\n" + "mov z20.b, #0x1\n" + "lsr x25, x25, #0x2\n" + "mov x23, %x[width]\n" + "addvl x22, x27, #2\n" + "neg z21.s, p2/M, z21.s\n" + "6:" // Bias: N loop + "mov x21, x22\n" + "mov x20, x25\n" + "mov z19.s, #0x0\n" + "mov z18.s, #0x0\n" + "7:" // Bias: K loop + "ld1b { z17.b }, p2/Z, [x21]\n" + "subs x20, x20, #0x1\n" + "ld1b { z16.b }, p2/Z, [x21, #1, MUL VL]\n" + "addvl x21, x21, #2\n" + "sdot z19.s, z17.b, z20.b\n" + "sdot z18.s, z16.b, z20.b\n" + "bgt 7b\n" + "mov x20, x23\n" + "add x22, x22, %x[out_stride]\n" + "whilelt p1.s, XZR, x20\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z17.s }, p1/Z, [%x[bias]]\n" + "subs x23, x23, x24\n" + "ld1w { z16.s }, p0/Z, [%x[bias], #1, MUL VL]\n" + "addvl %x[bias], %x[bias], #2\n" + "mla z17.s, p2/M, z19.s, z21.s\n" + "mla z16.s, p2/M, z18.s, z21.s\n" + "st1w { z17.s }, p2, [x27]\n" + "st1w { z16.s }, p2, [x27, #1, MUL VL]\n" + "add x27, x27, %x[out_stride]\n" + "bgt 6b\n" + "8:" // Bias: Done + ".inst 0xd503467f // SMSTOP\n" + : [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out), [scale] "+&r"(scale) + : [in_stride] "r"(in_stride), [input_zero_point] "r"(input_zero_point), [out_stride] "r"(out_stride), + [pad_row] "r"(pad_row), [scale_multiplier] "r"(scale_multiplier), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", + "p15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", + "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", + "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); +} + +#endif // Architectural features check. 
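For reference, this is how a caller is expected to drive the new packing routine. The sketch below mirrors the new test in test/tests/matmul_clamp_qai8_qai8p_qsi8cp_test.cpp; only the kai_* calls are part of the library, the function and buffer names are illustrative, and the blocking constants follow the KAI_ASSUME checks in kai_run_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme.

```
#include <cstddef>
#include <cstdint>
#include <vector>

#include "kai/kai_common.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme.h"

// Packs a k x n int8 RHS matrix together with its per-channel float scales and
// int32 biases. lhs_zero_point / lhs_scale describe the asymmetric per-tensor LHS
// quantization, dst_scale the per-tensor output quantization.
std::vector<uint8_t> pack_rhs_example(
    size_t n, size_t k, const int8_t* rhs_kxn, const float* rhs_scales, const int32_t* biases,
    int32_t lhs_zero_point, float lhs_scale, float dst_scale) {
    // Blocking expected by this micro-kernel: nr = 2 * VL(bytes) / kr channels, kr = 4, sr = 1.
    const size_t kr = 4;
    const size_t nr = 2 * kai_get_sme_vector_length_u8() / kr;
    const size_t sr = 1;

    std::vector<uint8_t> rhs_packed(kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme(n, k));

    // The LHS zero point is folded into the packed biases; lhs_scale * dst_scale is
    // folded into the packed per-channel scales.
    const kai_rhs_pack_qsi8_params params{
        .input_zero_point = lhs_zero_point,
        .scale_multiplier = lhs_scale * dst_scale,
    };

    kai_run_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme(
        1, n, k, nr, kr, sr, n * sizeof(int8_t), rhs_kxn, biases, rhs_scales, rhs_packed.data(), 0, &params);

    return rhs_packed;
}
```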
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme.h
new file mode 100644
index 0000000000000000000000000000000000000000..ac24eb12cdc67ced56ebc3c08bbbae807c5d6237
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme.h
@@ -0,0 +1,83 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <stddef.h>
+
+#include "kai/kai_common.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+/// Gets the n step value.
+///
+/// The starting column index must be divisible by `n_step`.
+///
+/// @return The n step value.
+size_t kai_get_n_step_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme(void);
+
+/// Gets the offset in bytes to the data element in the RHS matrix buffer.
+///
+/// @param[in] n_idx Column index.
+///
+/// @return The offset in bytes to the data element.
+size_t kai_get_rhs_offset_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme(size_t n_idx);
+
+/// Gets the offset in bytes to the data element in the bias buffer.
+///
+/// @param[in] n_idx Column index.
+///
+/// @return The offset in bytes to the data element.
+size_t kai_get_bias_offset_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme(size_t n_idx);
+
+/// Gets the offset in bytes to the data element in the packed RHS buffer.
+///
+/// @param[in] n_idx Column index.
+/// @param[in] k Number of rows.
+///
+/// @return The offset in bytes to the data element.
+size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme(size_t n_idx, size_t k);
+
+/// Gets the size in bytes of the packed RHS buffer.
+///
+/// @param[in] n Number of columns.
+/// @param[in] k Number of rows.
+///
+/// @return The size in bytes of the packed RHS buffer.
+size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme(size_t n, size_t k);
+
+/// Runs the RHS packing function for matrix multiplication.
+///
+/// The pointer of each buffer (RHS, bias and packed RHS) needs to be adjusted by the offset
+/// calculated using the following functions:
+///
+/// * RHS: @ref kai_get_rhs_offset_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme.
+/// * Bias: @ref kai_get_bias_offset_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme.
+/// * Output: @ref kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme.
+///
+/// @param[in] num_groups Number of groups. It must be 1.
+/// @param[in] n Number of columns of the output matrix.
+/// @param[in] k Common dimension between the LHS and RHS matrix.
+/// @param[in] nr Block size in N dimension. It must be 2 * kai_get_sme_vector_length_u8() / kr.
+/// @param[in] kr Block size in K dimension. It must be 4.
+/// @param[in] sr Number of kr splits. It must be 1.
+/// @param[in] rhs_stride Row stride in bytes of the RHS matrix.
+/// @param[in] rhs RHS matrix data buffer.
+/// @param[in] bias Bias matrix data buffer.
+/// @param[in] scale Scale data buffer.
+/// @param[out] rhs_packed Packed RHS matrix.
+/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0.
+/// @param[in] params Extra packing parameters. It must not be NULL.
+void kai_run_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_qsi8_params* params); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/test/reference/binary_elementwise.cpp b/test/reference/binary_elementwise.cpp index d5abd26cc18422ee78334b0a63d3f9f9a907c076..f65136572c6689a919eea05359e3e539a4f9a2ed 100644 --- a/test/reference/binary_elementwise.cpp +++ b/test/reference/binary_elementwise.cpp @@ -136,6 +136,18 @@ std::vector sub( lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width); } +template +std::vector sub( + const void* lhs, size_t lhs_height, size_t lhs_width, // + const void* rhs, size_t rhs_height, size_t rhs_width) { + return binary_elementwise_any_op_type( + lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width); +} + +template std::vector sub( + const void* lhs, size_t lhs_height, size_t lhs_width, // + const void* rhs, size_t rhs_height, size_t rhs_width); + std::vector mul( const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) { @@ -143,6 +155,22 @@ std::vector mul( lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width); } +template +std::vector mul( + const void* lhs, size_t lhs_height, size_t lhs_width, // + const void* rhs, size_t rhs_height, size_t rhs_width) { + return binary_elementwise_any_op_type( + lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width); +} + +template std::vector mul( + const void* lhs, size_t lhs_height, size_t lhs_width, // + const void* rhs, size_t rhs_height, size_t rhs_width); + +template std::vector mul( + const void* lhs, size_t lhs_height, size_t lhs_width, // + const void* rhs, size_t rhs_height, size_t rhs_width); + std::vector div( const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) { diff --git a/test/reference/binary_elementwise.hpp b/test/reference/binary_elementwise.hpp index e3d3a9e1fdc36c7846859ebced2e2a09ae8938a5..f2f5c0ab5522748433ecee11a2a4b6552fc834df 100644 --- a/test/reference/binary_elementwise.hpp +++ b/test/reference/binary_elementwise.hpp @@ -50,6 +50,25 @@ std::vector sub( const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width); +/// Elementwise subtraction. +/// +/// Broadcasting is supported for any dimension and both LHS and RHS operands. +/// +/// @tparam T The data type. +/// +/// @param[in] lhs The LHS data buffer. +/// @param[in] lhs_height The number of rows of the LHS matrix. +/// @param[in] lhs_width The number of columns of the LHS matrix. +/// @param[in] rhs The RHS data buffer. +/// @param[in] rhs_height The number of rows of the RHS matrix. +/// @param[in] rhs_width The number of columns of the LHS matrix. +/// +/// @return The result matrix. +template +std::vector sub( + const void* lhs, size_t lhs_height, size_t lhs_width, // + const void* rhs, size_t rhs_height, size_t rhs_width); + /// Elementwise multiplication. /// /// Broadcasting is supported for any dimension and both LHS and RHS operands. 
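A usage note on these typed overloads: a 1x1 operand broadcasts against a full row, which is how the new matmul_pack reference combines the per-channel RHS scales with the single lhs_scale * dst_scale multiplier. A minimal sketch, assuming the template parameter is the element data type as the @tparam documentation states (the helper name is illustrative):

```
#include <cstddef>
#include <cstdint>
#include <vector>

#include "test/reference/binary_elementwise.hpp"

// Multiplies n per-channel scales (a 1 x n row) by a single multiplier (a 1 x 1
// operand broadcast across the row), mirroring the call made by
// matmul_pack_rhs_nxk_static_quantized.
std::vector<uint8_t> scale_row(const float* scales, size_t n, float multiplier) {
    return kai::test::mul<float>(scales, 1, n, &multiplier, 1, 1);
}
```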
@@ -68,6 +87,25 @@ std::vector mul( const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width); +/// Elementwise multiplication. +/// +/// Broadcasting is supported for any dimension and both LHS and RHS operands. +/// +/// @tparam T The data type. +/// +/// @param[in] lhs The LHS data buffer. +/// @param[in] lhs_height The number of rows of the LHS matrix. +/// @param[in] lhs_width The number of columns of the LHS matrix. +/// @param[in] rhs The RHS data buffer. +/// @param[in] rhs_height The number of rows of the RHS matrix. +/// @param[in] rhs_width The number of columns of the LHS matrix. +/// +/// @return The result matrix. +template +std::vector mul( + const void* lhs, size_t lhs_height, size_t lhs_width, // + const void* rhs, size_t rhs_height, size_t rhs_width); + /// Elementwise division. /// /// Broadcasting is supported for any dimension and both LHS and RHS operands. diff --git a/test/reference/fill.cpp b/test/reference/fill.cpp index 49b16987b8c9b1399db247ea1fdced6322e5382e..faad8547b5d9576a7e78a0bb3c34bf698b094422 100644 --- a/test/reference/fill.cpp +++ b/test/reference/fill.cpp @@ -20,6 +20,9 @@ #include "test/common/float16.hpp" #include "test/common/int4.hpp" #include "test/common/memory.hpp" +#include "test/common/numeric_limits.hpp" +#include "test/common/round.hpp" +#include "test/common/type_traits.hpp" namespace kai::test { @@ -118,11 +121,67 @@ std::vector fill_matrix_random(size_t height, size_t width, const DataF } } +template +Value get_random(uint64_t seed, Value min_value, Value max_value) { + static_assert(is_floating_point || is_integral); + static_assert(size_in_bits <= 32); + + using Distribution = std::conditional_t< + is_floating_point, std::uniform_real_distribution, + std::conditional_t< + is_signed, std::uniform_int_distribution, std::uniform_int_distribution>>; + + std::mt19937 rnd(seed); + Distribution dist(min_value, max_value); + + return static_cast(dist(rnd)); +} + +template +Value get_random(uint64_t seed) { + if constexpr (is_floating_point) { + return get_random(seed, static_cast(0.0F), static_cast(1.0F)); + } else { + return get_random(seed, numeric_lowest, numeric_highest); + } +} + +template float get_random(uint64_t seed); +template int32_t get_random(uint64_t seed); + +template +std::vector fill_random(size_t length, uint64_t seed, Value min_value, Value max_value) { + static_assert(is_floating_point || is_integral); + static_assert(size_in_bits <= 32); + + using Distribution = std::conditional_t< + is_floating_point, std::uniform_real_distribution, + std::conditional_t< + is_signed, std::uniform_int_distribution, std::uniform_int_distribution>>; + + std::mt19937 rnd(seed); + Distribution dist(min_value, max_value); + + std::vector data(round_up_division(length * size_in_bits, 8)); + + for (size_t i = 0; i < length; ++i) { + write_array(data.data(), i, static_cast(dist(rnd))); + } + + return data; +} + template std::vector fill_random(size_t length, uint64_t seed) { - return fill_matrix_random_raw(1, length, seed); + if constexpr (is_floating_point) { + return fill_random(length, seed, static_cast(0.0F), static_cast(1.0F)); + } else { + return fill_random(length, seed, numeric_lowest, numeric_highest); + } } template std::vector fill_random(size_t length, uint64_t seed); +template std::vector fill_random(size_t length, uint64_t seed); +template std::vector fill_random(size_t length, uint64_t seed); } // namespace kai::test diff --git 
a/test/reference/fill.hpp b/test/reference/fill.hpp index 80093952a0c7bb7408b63dbbe59528bc105cbfb8..df5f03d36ae76651a71520addce0a6bd06d9758c 100644 --- a/test/reference/fill.hpp +++ b/test/reference/fill.hpp @@ -24,6 +24,28 @@ class DataFormat; /// @return The data buffer for the matrix. std::vector fill_matrix_random(size_t height, size_t width, const DataFormat& format, uint64_t seed); +/// Gets a random value. +/// +/// @tparam Value The data type. +/// +/// @param[in] seed The random seed. +/// +/// @return The random value. +template +Value get_random(uint64_t seed); + +/// Gets a random value. +/// +/// @tparam Value The data type. +/// +/// @param[in] seed The random seed. +/// @param[in] min_value The minimum value. +/// @param[in] max_value The maximum value. +/// +/// @return The random value. +template +Value get_random(uint64_t seed, Value min_value, Value max_value); + /// Creates a new data buffer filled with random data. /// /// @tparam Value The data type. @@ -35,4 +57,17 @@ std::vector fill_matrix_random(size_t height, size_t width, const DataF template std::vector fill_random(size_t length, uint64_t seed); +/// Creates a new data buffer filled with random data. +/// +/// @tparam Value The data type. +/// +/// @param[in] length The number of elements. +/// @param[in] seed The random seed. +/// @param[in] min_value The minimum value. +/// @param[in] max_value The maximum value. +/// +/// @return The data buffer. +template +std::vector fill_random(size_t length, uint64_t seed, Value min_value, Value max_value); + } // namespace kai::test diff --git a/test/reference/matmul_pack.cpp b/test/reference/matmul_pack.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1c638f576587a9a98e0aa4ccbad9be78e351760f --- /dev/null +++ b/test/reference/matmul_pack.cpp @@ -0,0 +1,53 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/matmul_pack.hpp" + +#include +#include +#include + +#include "test/common/round.hpp" +#include "test/reference/binary_elementwise.hpp" +#include "test/reference/pack.hpp" +#include "test/reference/reduce.hpp" +#include "test/reference/reorder.hpp" + +namespace kai::test { + +template +std::vector matmul_pack_rhs_nxk_static_quantized( + const void* data, const void* scales, Scale lhs_scale, Scale dst_scale, const void* biases, + ZeroPoint lhs_zero_point, size_t n, size_t k, size_t block_height, size_t block_width) { + // The RHS data matrix is reordered according to the blocking parameters. + const auto reordered_data = reorder_block(data, n, k, block_height, block_width); + + // The effective per-channel scale: + // final_scales[n_index] = lhs_scale * rhs_scales[n_index] * dst_scale. + const auto scale_multiplier = lhs_scale * dst_scale; + auto combined_scales = mul(scales, 1, n, &scale_multiplier, 1, 1); + combined_scales.resize(round_up_multiple(n, block_height) * sizeof(Scale)); // Pads with 0s. + + // The effective per-channel biases: + // final_biases[n_index] = biases[n_index] - lhs_zero_point * sum(data[n_index, :]). + const auto row_sum = reduce_add_x(data, n, k); + const auto row_sum_times_lhs_zp = mul(row_sum.data(), n, k, &lhs_zero_point, 1, 1); + auto combined_biases = sub(biases, 1, n, row_sum_times_lhs_zp.data(), 1, n); + combined_biases.resize(round_up_multiple(n, block_height) * sizeof(ZeroPoint)); // Pads with 0s. + + // Packs the effective biases followed by the data block followed by the effective scales for the block. 
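+    // For each block of `block_height` channels the packed buffer therefore holds:
+    // `block_height` bias values, then `block_height * round_up_multiple(k, block_width)`
+    // data values, then `block_height` scale values. This matches the per-block stride
+    // computed by kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme.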
+ auto packed_rhs = pack_zero_points_data_scales_per_block( + combined_biases.data(), reordered_data.data(), combined_scales.data(), round_up_division(n, block_height), + block_height, block_height * round_up_multiple(k, block_width), block_height); + + return packed_rhs; +} + +template std::vector matmul_pack_rhs_nxk_static_quantized( + const void* data, const void* scales, float lhs_scale, float dst_scale, const void* biases, int32_t lhs_zero_point, + size_t n, size_t k, size_t block_height, size_t block_width); + +} // namespace kai::test diff --git a/test/reference/matmul_pack.hpp b/test/reference/matmul_pack.hpp new file mode 100644 index 0000000000000000000000000000000000000000..30646c9875d8be9ef80ad47362a203ed4c68881c --- /dev/null +++ b/test/reference/matmul_pack.hpp @@ -0,0 +1,44 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace kai::test { + +/// Packs the RHS buffer for static quantized GeMM. +/// +/// The RHS matrix must be transposed. +/// +/// This function can be used when the following conditions are met: +/// * LHS, RHS and DST data types have the same size and are quantized. +/// * LHS is asymmetric per-tensor, RHS is symmetric per-channel and DST is asymmetric per-tensor. +/// +/// @tparam Data The data type of the RHS matrix. +/// @tparam Scale The data type of the quantization scales. +/// @tparam ZeroPoint The data type of the quantization zero points and the operator biases. +/// +/// @param[in] data The data buffer of the RHS matrix. +/// @param[in] scales The quantization scales of the RHS matrix. +/// @param[in] lhs_scale The quantization scale of the LHS matrix. +/// @param[in] dst_scale The quantization scale of the DST matrix. +/// @param[in] biases The biases of the operator. +/// @param[in] lhs_zero_point The quantization zero point of the LHS matrix. +/// @param[in] n The number of columns of the non-transposed RHS matrix. +/// @param[in] k The number of rows of the non-transposed RHS matrix. +/// @param[in] block_height The number of rows of a data block (N dimension). +/// @param[in] block_width The number of columns of a data block (K dimension). +/// +/// @return The packed RHS. +template +std::vector matmul_pack_rhs_nxk_static_quantized( + const void* data, const void* scales, Scale lhs_scale, Scale dst_scale, const void* biases, + ZeroPoint lhs_zero_point, size_t n, size_t k, size_t block_height, size_t block_width); + +} // namespace kai::test diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp index 221ba36079ec8d6c6bf59289826a37ec2525b300..c7a5f12df8267260ee58a6f60599e768338bddee 100644 --- a/test/reference/pack.cpp +++ b/test/reference/pack.cpp @@ -195,6 +195,56 @@ std::vector pack_data_scales( return dst; } +template +std::vector pack_zero_points_data_scales_per_block( + const void* zero_points, const void* data, const void* scales, size_t num_blocks, size_t block_num_zero_points, + size_t block_num_data, size_t block_num_scales) { + // Only data is allowed to be sub-byte. + KAI_ASSUME(size_in_bits % 8 == 0); + KAI_ASSUME(size_in_bits % 8 == 0); + + // Checks for memory alignment. 
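+    // (In other words, each section of a packed block has to end on a byte boundary so
+    // that the next section can be written with a byte-granular pointer; with whole-byte
+    // zero-point and scale types, as used here, these checks always hold.)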
+ KAI_ASSUME(size_in_bits % size_in_bits == 0); + KAI_ASSUME( + (block_num_zero_points * size_in_bits + block_num_data * size_in_bits) % size_in_bits == + 0); + KAI_ASSUME( + (block_num_data * size_in_bits + block_num_scales * size_in_bits) % size_in_bits == 0); + + std::vector dst(round_up_division( + num_blocks * + (block_num_zero_points * size_in_bits + block_num_data * size_in_bits + + block_num_scales * size_in_bits), + 8)); + auto* dst_ptr = dst.data(); + + for (size_t block_no = 0; block_no < num_blocks; ++block_no) { + for (size_t i = 0; i < block_num_zero_points; ++i) { + write_array( + dst_ptr, i, read_array(zero_points, block_no * block_num_zero_points + i)); + } + dst_ptr += block_num_zero_points * sizeof(ZeroPoint); + + for (size_t i = 0; i < block_num_data; ++i) { + write_array(dst_ptr, i, read_array(data, block_no * block_num_data + i)); + } + dst_ptr += round_up_division(block_num_data * size_in_bits, 8); + + for (size_t i = 0; i < block_num_scales; ++i) { + write_array(dst_ptr, i, read_array(scales, block_no * block_num_scales + i)); + } + dst_ptr += block_num_scales * sizeof(Scale); + } + + KAI_ASSERT(dst_ptr == &*dst.end()); + + return dst; +} + +template std::vector pack_zero_points_data_scales_per_block( + const void* zero_points, const void* data, const void* scales, size_t num_blocks, size_t block_num_zero_points, + size_t block_num_data, size_t block_num_scales); + template std::vector pack_data_scales_interleave_block( const void* data, const void* scales, size_t height, size_t width, size_t quant_width) { diff --git a/test/reference/pack.hpp b/test/reference/pack.hpp index 10d76a7f6a289fea0a5217c391cea77b703cedaa..501adb262ad283214237c6663b1fcc853ae275d5 100644 --- a/test/reference/pack.hpp +++ b/test/reference/pack.hpp @@ -79,6 +79,70 @@ template std::vector pack_data_scales( const void* data, const void* scales, size_t height, size_t width, size_t quant_width); +/// Packs the zero point, data and scale into a single buffer. +/// +/// ``` +/// Data matrix: +/// +/// +-----------------+ +/// | q00 q01 q02 q03 | +/// | q10 q11 q12 q13 | +/// | q20 q21 q22 q23 | +/// | q30 q31 q32 q33 | +/// | ............... | +/// : ............... : +/// +/// Scales for each row: +/// +/// +----+ +/// | s0 | +/// | s1 | +/// | s2 | +/// | s3 | +/// | .. | +/// : .. : +/// +/// Zero points for each row: +/// +/// +----+ +/// | z0 | +/// | z1 | +/// | z2 | +/// | z3 | +/// | .. | +/// : .. : +/// ``` +/// +/// The packed data has each zero point followed by the data row followed by the scale. +/// +/// ``` +/// Packed data: +/// +/// +----+-----------------+----+ +/// | z0 | q00 q01 q02 q03 | s0 | +/// | z1 | q10 q11 q12 q13 | s1 | +/// | z2 | q20 q21 q22 q23 | s2 | +/// | z3 | q30 q31 q32 q33 | s3 | +/// | .. | ............... | .. | +/// : .. : ............... : .. : +/// ``` +/// +/// @tparam Data The data type of the data. +/// @tparam Scale The data type of the scale. +/// @tparam ZeroPoint The data type of the zero point. +/// +/// @param[in] data The data buffer. +/// @param[in] scales The scales buffer. +/// @param[in] zero_points The zero points buffer. +/// @param[in] height The number of rows. +/// @param[in] width The number of columns. +/// +/// @return The packed data buffer. 
+template +std::vector pack_zero_points_data_scales_per_block( + const void* zero_points, const void* data, const void* scales, size_t num_blocks, size_t block_num_zero_points, + size_t block_num_data, size_t block_num_scales); + /// Packs the quantized data and the quantization scale into a single buffer. /// /// ``` diff --git a/test/reference/reduce.cpp b/test/reference/reduce.cpp index 0e83b9bfdc7cf226635100e03cdd84c19497e207..5332cef14205fdb93154f08b4b294f147363136e 100644 --- a/test/reference/reduce.cpp +++ b/test/reference/reduce.cpp @@ -15,6 +15,7 @@ #include "test/common/data_type.hpp" #include "test/common/int4.hpp" #include "test/common/memory.hpp" +#include "test/common/round.hpp" namespace kai::test { @@ -110,4 +111,23 @@ std::vector reduce_add( return reduce_any_op(src, src_format, height, width, dst_format, dimension); } +template +std::vector reduce_add_x(const void* src, size_t height, size_t width) { + std::vector dst(round_up_division(height * size_in_bits, 8)); + + for (size_t y = 0; y < height; ++y) { + Accumulator acc = 0; + + for (size_t x = 0; x < width; ++x) { + acc += static_cast(read_array(src, y * width + x)); + } + + write_array(dst.data(), y, acc); + } + + return dst; +} + +template std::vector reduce_add_x(const void* src, size_t height, size_t width); + } // namespace kai::test diff --git a/test/reference/reduce.hpp b/test/reference/reduce.hpp index f6ba197a0f5bdcf9ef16a35fe4ee4e74cfe50450..51f18a0f2ef89f82c67f2cf29de3406b5d6ff294 100644 --- a/test/reference/reduce.hpp +++ b/test/reference/reduce.hpp @@ -33,4 +33,17 @@ std::vector reduce_add( const void* src, const DataFormat& src_format, size_t height, size_t width, const DataFormat& dst_format, size_t dimension); +/// Accumulates the matrix along the first dimension. +/// +/// @tparam Value The data type of the matrix value. +/// @tparam Accumulator The data type of the accumulator. +/// +/// @param[in] src The input data. +/// @param[in] height The number of rows of the input matrix. +/// @param[in] width The number of columns of the input matrix. +/// +/// @return The vector containing the sum of each input matrix row. 
+template +std::vector reduce_add_x(const void* src, size_t height, size_t width); + } // namespace kai::test diff --git a/test/reference/reorder.cpp b/test/reference/reorder.cpp new file mode 100644 index 0000000000000000000000000000000000000000..564f96f6ae6f9dbd4cd7fd7d24834bf079f6fe9e --- /dev/null +++ b/test/reference/reorder.cpp @@ -0,0 +1,50 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/reorder.hpp" + +#include +#include +#include + +#include "test/common/memory.hpp" +#include "test/common/round.hpp" + +namespace kai::test { + +template +std::vector reorder_block( + const void* src, size_t height, size_t width, size_t block_height, size_t block_width) { + const auto num_dst_elements = round_up_multiple(height, block_height) * round_up_multiple(width, block_width); + const auto dst_size = round_up_division(num_dst_elements * size_in_bits, 8); + + std::vector dst(dst_size); + size_t dst_index = 0; + + for (size_t y_block = 0; y_block < height; y_block += block_height) { + for (size_t x_block = 0; x_block < width; x_block += block_width) { + for (size_t y_element = 0; y_element < block_height; ++y_element) { + for (size_t x_element = 0; x_element < block_width; ++x_element) { + const auto y = y_block + y_element; + const auto x = x_block + x_element; + + if (y < height && x < width) { + write_array(dst.data(), dst_index, read_array(src, y * width + x)); + } + + ++dst_index; + } + } + } + } + + return dst; +} + +template std::vector reorder_block( + const void* src, size_t height, size_t width, size_t block_height, size_t block_width); + +} // namespace kai::test diff --git a/test/reference/reorder.hpp b/test/reference/reorder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..48449e37c197d710b5cc73b6d28f3ae862acdef0 --- /dev/null +++ b/test/reference/reorder.hpp @@ -0,0 +1,72 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace kai::test { + +/// Reorders the input matrix block by block. +/// +/// Example: +/// +/// The input matrix: 5x7. +/// +/// ``` +/// +-----------------------------+ +/// | a00 a01 a02 a03 a04 a05 a06 | +/// | a10 a11 a12 a13 a14 a15 a16 | +/// | a20 a21 a22 a23 a24 a25 a26 | +/// | a30 a31 a32 a33 a34 a35 a36 | +/// | a40 a41 a42 a43 a44 a45 a46 | +/// +-----------------------------+ +/// ``` +/// +/// The matrix is divided into blocks of 2x3. +/// At the right and bottom edges, the partial blocks are padded with 0s. 
+///
+/// ```
+/// +-------------+-------------+-------------+
+/// | a00 a01 a02 | a03 a04 a05 | a06 0   0   |
+/// | a10 a11 a12 | a13 a14 a15 | a16 0   0   |
+/// +-------------+-------------+-------------+
+/// | a20 a21 a22 | a23 a24 a25 | a26 0   0   |
+/// | a30 a31 a32 | a33 a34 a35 | a36 0   0   |
+/// +-------------+-------------+-------------+
+/// | a40 a41 a42 | a43 a44 a45 | a46 0   0   |
+/// | 0   0   0   | 0   0   0   | 0   0   0   |
+/// +-------------+-------------+-------------+
+/// ```
+///
+/// Each block is then flattened to get the final reordered matrix:
+///
+/// ```
+/// +-------------------------+-------------------------+-------------------------+
+/// | a00 a01 a02 a10 a11 a12 | a03 a04 a05 a13 a14 a15 | a06 0   0   a16 0   0   |
+/// +-------------------------+-------------------------+-------------------------+
+/// | a20 a21 a22 a30 a31 a32 | a23 a24 a25 a33 a34 a35 | a26 0   0   a36 0   0   |
+/// +-------------------------+-------------------------+-------------------------+
+/// | a40 a41 a42 0   0   0   | a43 a44 a45 0   0   0   | a46 0   0   0   0   0   |
+/// +-------------------------+-------------------------+-------------------------+
+/// ```
+///
+/// @tparam T The data type.
+///
+/// @param[in] src The input data.
+/// @param[in] height The number of rows of the input matrix.
+/// @param[in] width The number of columns of the input matrix.
+/// @param[in] block_height The number of rows of a block.
+/// @param[in] block_width The number of columns of a block.
+///
+/// @return The reordered matrix.
+template
+std::vector reorder_block(
+    const void* src, size_t height, size_t width, size_t block_height, size_t block_width);
+
+} // namespace kai::test
diff --git a/test/reference/transpose.cpp b/test/reference/transpose.cpp
index 95cbc8e22566888fc13e7353abc6b35d5e837425..aa044caee9ebe00367678f24c4a9dbbde4d880ea 100644
--- a/test/reference/transpose.cpp
+++ b/test/reference/transpose.cpp
@@ -13,6 +13,8 @@
 
 #include "kai/kai_common.h"
 #include "test/common/data_type.hpp"
+#include "test/common/memory.hpp"
+#include "test/common/round.hpp"
 
 namespace kai::test {
 
@@ -36,4 +38,20 @@ std::vector transpose(const void* data, DataType data_type, size_t heig
     return output;
 }
 
+template
+std::vector transpose(const void* src, size_t height, size_t width) {
+    std::vector dst(round_up_division(height * width * size_in_bits, 8));
+
+    for (size_t y = 0; y < width; ++y) {
+        for (size_t x = 0; x < height; ++x) {
+            write_array(dst.data(), y * height + x, read_array(src, x * width + y));
+        }
+    }
+
+    return dst;
+}
+
+template std::vector transpose(const void* src, size_t height, size_t width);
+template std::vector transpose(const void* src, size_t height, size_t width);
+
 } // namespace kai::test
diff --git a/test/reference/transpose.hpp b/test/reference/transpose.hpp
index 2c2f6a83cdd27f2d69010449d72899a61c798e2e..ba6f5648c74d2d18bc43ca0785c12b54135f8ac6 100644
--- a/test/reference/transpose.hpp
+++ b/test/reference/transpose.hpp
@@ -24,4 +24,16 @@ namespace kai::test {
 /// @return The transposed matrix.
 std::vector transpose(const void* data, DataType data_type, size_t height, size_t width);
 
+/// Transposes the matrix.
+///
+/// @tparam T The data type.
+///
+/// @param[in] src The data buffer of the source matrix.
+/// @param[in] height The number of rows of the source matrix.
+/// @param[in] width The number of columns of the source matrix.
+///
+/// @return The transposed matrix.
+template
+std::vector transpose(const void* src, size_t height, size_t width);
+
 } // namespace kai::test
diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cp_test.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a66864a72dd7a78fc3e15ffb095c4866a9e5177b
--- /dev/null
+++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cp_test.cpp
@@ -0,0 +1,117 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include
+#include
+#include
+#include
+
+#include "kai/kai_common.h"
+#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme.h"
+#include "test/common/sme.hpp"
+#include "test/reference/fill.hpp"
+#include "test/reference/matmul_pack.hpp"
+#include "test/reference/transpose.hpp"
+
+namespace kai::test {
+
+namespace {
+
+struct GemmVariant {
+    size_t acc_height;
+    size_t acc_width;
+    size_t acc_fanin;
+
+    size_t (*fn_pack_rhs_get_packed_rhs_size)(size_t n, size_t k);
+    void (*fn_pack_rhs_run)(
+        size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
+        const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes,
+        const struct kai_rhs_pack_qsi8_params* params);
+};
+
+struct GemmShape {
+    size_t m;
+    size_t n;
+    size_t k;
+};
+
+const std::array gemm_variants = {
+    GemmVariant{
+        .acc_height = 2 * get_sme_vector_length(),
+        .acc_width = 2 * get_sme_vector_length(),
+        .acc_fanin = sizeof(int32_t) / sizeof(int8_t),
+
+        .fn_pack_rhs_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme,
+        .fn_pack_rhs_run = kai_run_rhs_pack_kxn_qsi8cp2vlx4sb_qsi8_f32_i32_sme,
+    },
+};
+
+const std::array gemm_shapes = {
+    GemmShape{1, 1, 1},  //
+    GemmShape{
+        2 * get_sme_vector_length(), 2 * get_sme_vector_length(),
+        sizeof(int32_t) / sizeof(int8_t)},  //
+    GemmShape{20, 30, 40},   //
+    GemmShape{1, 49, 21},    //
+    GemmShape{23, 1, 43},    //
+    GemmShape{32, 14, 1},    //
+    GemmShape{123, 85, 45},  //
+};
+
+void run_test(const GemmShape& shape, const GemmVariant& variant) {
+    const uint64_t seed = 0;
+
+    // Generates the input data.
+    const auto rhs_qsi8 = fill_random(shape.k * shape.n, seed + 0);
+    const auto rhs_qsi8_scales = fill_random(shape.n, seed + 1);
+    const auto biases_qsi32 = fill_random(shape.n, seed + 2, -1000, 1000);
+
+    const auto lhs_zero_point = get_random(seed + 3, -1000, 1000);
+    const auto lhs_scale = get_random(seed + 4);
+    const auto dst_scale = get_random(seed + 5);
+
+    // Runs the reference implementation.
+    const auto ref_rhs_qsi8_t = transpose(rhs_qsi8.data(), shape.k, shape.n);
+    const auto ref_packed_rhs = matmul_pack_rhs_nxk_static_quantized(
+        ref_rhs_qsi8_t.data(), rhs_qsi8_scales.data(), lhs_scale, dst_scale, biases_qsi32.data(), lhs_zero_point,
+        shape.n, shape.k, variant.acc_width, variant.acc_fanin);
+
+    // Runs the implementation under test.
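+    // (The blocking arguments passed below must satisfy the KAI_ASSUME checks in the
+    // packing routine: nr = acc_width, kr = acc_fanin = 4 and sr = 1.)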
+    const auto imp_packed_rhs_size = variant.fn_pack_rhs_get_packed_rhs_size(shape.n, shape.k);
+    ASSERT_EQ(imp_packed_rhs_size, ref_packed_rhs.size());
+    std::vector imp_packed_rhs(imp_packed_rhs_size);
+
+    kai_rhs_pack_qsi8_params imp_params{
+        .input_zero_point = lhs_zero_point,
+        .scale_multiplier = lhs_scale * dst_scale,
+    };
+
+    variant.fn_pack_rhs_run(
+        1, shape.n, shape.k, variant.acc_width, variant.acc_fanin, 1, shape.n * sizeof(int8_t), rhs_qsi8.data(),
+        biases_qsi32.data(), rhs_qsi8_scales.data(), imp_packed_rhs.data(), 0, &imp_params);
+
+    for (size_t i = 0; i < ref_packed_rhs.size(); ++i) {
+        ASSERT_EQ(imp_packed_rhs[i], ref_packed_rhs[i]);
+    }
+}
+
+using ThisTest = testing::TestWithParam>;
+
+TEST_P(ThisTest, EndToEnd) {
+    const auto& [variant, shape] = GetParam();
+
+    run_test(shape, variant);
+}
+
+} // namespace
+
+INSTANTIATE_TEST_SUITE_P(
+    matmul_clamp_qai8_qai8p_qsi8cp, ThisTest,
+    testing::Combine(testing::ValuesIn(gemm_variants), testing::ValuesIn(gemm_shapes)));
+
+} // namespace kai::test
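The per-channel quantities that the micro-kernel and the reference packing fold into the packed buffer reduce to a short scalar loop. The sketch below only restates the math already spelled out in matmul_pack_rhs_nxk_static_quantized (effective scale = lhs_scale * rhs_scale * dst_scale, effective bias = bias - lhs_zero_point * column sum); the function name is illustrative and it is not part of the test suite.

```
#include <cstddef>
#include <cstdint>
#include <vector>

// rhs is the non-transposed k x n RHS matrix, exactly as handed to the packing routine.
void compute_effective_rhs_params(
    const int8_t* rhs, const float* rhs_scales, const int32_t* biases, size_t n, size_t k,
    float lhs_scale, float dst_scale, int32_t lhs_zero_point,
    std::vector<float>& effective_scales, std::vector<int32_t>& effective_biases) {
    effective_scales.assign(n, 0.0F);
    effective_biases.assign(n, 0);

    for (size_t col = 0; col < n; ++col) {
        int32_t col_sum = 0;
        for (size_t row = 0; row < k; ++row) {
            col_sum += static_cast<int32_t>(rhs[row * n + col]);
        }

        effective_scales[col] = lhs_scale * rhs_scales[col] * dst_scale;
        effective_biases[col] = biases[col] - lhs_zero_point * col_sum;
    }
}
```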