diff --git a/CMakeLists.txt b/CMakeLists.txt index b9720632949f9d7903f817af7d44ef9bf46f874b..cb8c6a5f4993ecfe126e34643c64b17a6c767ccd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -146,6 +146,8 @@ set(KLEIDIAI_FILES_SME kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c + kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c ) set(KLEIDIAI_FILES_SME2 @@ -154,6 +156,7 @@ set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c + kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c ) add_library(kleidiai) @@ -219,13 +222,16 @@ if(KLEIDIAI_BUILD_TESTS) test/reference/binary_elementwise.cpp test/reference/matmul.cpp + test/reference/matmul_pack.cpp test/reference/fill.cpp test/reference/pack.cpp test/reference/pad.cpp + test/reference/clamp.cpp test/reference/quantize.cpp test/reference/reduce.cpp test/reference/transpose.cpp test/reference/cast.cpp + test/reference/reorder.cpp ) target_compile_options(kleidiai_test_framework @@ -248,6 +254,7 @@ if(KLEIDIAI_BUILD_TESTS) test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp + test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp ) target_link_libraries(kleidiai_test diff --git a/docs/imgs/kai_rhs_packing_pattern_2.png b/docs/imgs/kai_rhs_packing_pattern_2.png new file mode 100644 index 0000000000000000000000000000000000000000..d10d1fc038c8f118dbb07d316f7eb49dfdffe8ac Binary files /dev/null and b/docs/imgs/kai_rhs_packing_pattern_2.png differ diff --git a/docs/imgs/kai_rhs_packing_pattern_2.png.license b/docs/imgs/kai_rhs_packing_pattern_2.png.license new file mode 100644 index 0000000000000000000000000000000000000000..efa11a946326059bc0e2bd4bc256e14362240ee2 --- /dev/null +++ b/docs/imgs/kai_rhs_packing_pattern_2.png.license @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates + +# SPDX-License-Identifier: Apache-2.0 diff --git a/kai/kai_common.h b/kai/kai_common.h index 332528504c2802abec6de52975a9957e04a4c53c..cbd03a95c6bf16e8ca9183ef04c8979f86caed59 100644 --- a/kai/kai_common.h +++ b/kai/kai_common.h @@ -175,6 +175,19 @@ struct kai_rhs_pack_qs4cxs1s0_param { uint8_t rhs_zero_point; /**< RHS Matrix quantization zero-point */ }; +/// RHS packing parameter for 8-bit quantization. +struct kai_rhs_pack_qsi8_params { + int32_t lhs_zero_point; ///< LHS quantization zero point. + float scale_multiplier; ///< Product of input (refers to lhs and rhs) and output quantization scales. +}; + +/// Requantization and clamp parameters for GEMM/GEMV output stage. +struct kai_matmul_requantize32_params { + int32_t min_value; ///< Minimum output value. + int32_t max_value; ///< Maximum output value. + int32_t output_zero_point; ///< Output quantization zero point. +}; + #ifdef __cplusplus } #endif diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 4819368624f33e1c826363691021b5486e69a704..2d1134c7993b70b4a0cc7b2f1cae622550d9e416 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -152,6 +152,13 @@ kai_c_library( ], ) +kai_c_library( + name = "clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", + srcs = ["matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c"], + hdrs = ["matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h"], + cpu_uarch = kai_cpu_sme(), +) + cc_library( name = "clamp_f32_qai8dxp_qsi4cxp_interface", hdrs = ["matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h"], @@ -312,6 +319,13 @@ kai_c_library( cpu_uarch = kai_cpu_bf16(), ) +kai_c_library( + name = "lhs_pack_x8p2vlx4_x8_sme", + srcs = ["pack/kai_lhs_pack_x8p2vlx4_x8_sme.c"], + hdrs = ["pack/kai_lhs_pack_x8p2vlx4_x8_sme.h"], + cpu_uarch = kai_cpu_sme(), +) + kai_c_library( name = "rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon", srcs = ["pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c"], @@ -368,6 +382,13 @@ kai_c_library( cpu_uarch = kai_cpu_sme(), ) +kai_c_library( + name = "rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", + srcs = ["pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c"], + hdrs = ["pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h"], + cpu_uarch = kai_cpu_sme(), +) + kai_c_library( name = "rhs_pack_nxk_qsi4cxp_qs4cxs1s0", srcs = ["pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c"], @@ -540,10 +561,12 @@ kai_c_library( ":clamp_f32_qsi8d32p_qsi4c32p_dotprod", ":clamp_f32_qsi8d32p_qsi4c32p_i8mm", ":clamp_f32_qsi8d32p_qsi4c32p_interface", + ":clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", ":kai_files_sme", ":kai_files_sme2", ":lhs_pack_bf16p8x4_f16_neon", ":lhs_pack_f32p2vlx1_f32_sme", + ":lhs_pack_x8p2vlx4_x8_sme", ":lhs_quant_pack_bf16p1x4_f32_neon", ":lhs_quant_pack_bf16p8x4_f32_neon", ":lhs_quant_pack_qai8dxp_f32", @@ -556,6 +579,7 @@ kai_c_library( ":rhs_pack_kxn_f32pbiasf32_f32_f32_neon", ":rhs_pack_kxn_qsi4c32p_qsu4c32s1s0", ":rhs_pack_kxn_qsi4cxp_qs4cxs1s0", + ":rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", ":rhs_pack_kxn_qsi8cxp_qsi8cx_neon", ":rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme", ":rhs_pack_nxk_qsi4c32p_qsu4c32s1s0", diff --git a/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c new file mode 100644 index 0000000000000000000000000000000000000000..ef38be5f92a39e09deec0dd13dd956444c8b869a --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c @@ -0,0 +1,399 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include "kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_mr = 2; +static const size_t kai_nr = 2; +static const size_t kai_kr = 4; +static const size_t kai_sr = 1; + +size_t kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { + return kai_mr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { + return kai_nr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { + return kai_mr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { + return kai_nr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(size_t m_idx, size_t k) { + KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); + return m_idx * kai_roundup(k, kai_kr) * sizeof(int8_t); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); + return n_idx * (sizeof(int32_t) + kai_roundup(k, kai_kr) * sizeof(int8_t) + sizeof(float)); +} + +size_t kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); + KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); + + return m_idx * dst_stride + n_idx * sizeof(int8_t); +} + +size_t kai_get_dst_size_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(size_t m, size_t n) { + return m * n * sizeof(int8_t); +} + +void kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, + size_t dst_stride_col, const struct kai_matmul_requantize32_params* params) { + KAI_ASSUME(dst_stride_col == sizeof(int8_t)); + + typedef struct { + const void* A; + const void* B; + + void* C; + uint64_t ldcb; + uint64_t M, N, K; + int32_t min; + int32_t max; + int32_t result_zero_point; + + void* accumulator_buffer; + uint64_t flags; + } KernelArgs; + + KernelArgs args; + + args.A = lhs_packed; + args.B = rhs_packed; + + args.C = dst; + args.ldcb = dst_stride_row; + args.M = m; + args.N = n; + args.K = k; + args.min = params->min_value; + args.max = params->max_value; + args.result_zero_point = params->output_zero_point; + + args.accumulator_buffer = NULL; + args.flags = 0; + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "ldr w14, [%x[args], %[offsetof_M]]\n" + "mov x13, #0x0\n" + "mov x11, #0x0\n" + "ptrue p1.b\n" + ".inst 0x25207811 // ptrue pn9.b\n" + "ldr w10, [%x[args], %[offsetof_N]]\n" + "ldr x9, [%x[args], %[offsetof_A]]\n" + "1:" // M loop + "ldr x28, [%x[args], %[offsetof_B]]\n" + "2:" // N loop + ".inst 0x25aa4570 // whilelt pn8.s, x11, x10, VLx2\n" + ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" + "mov x27, x9\n" + ".inst 0xa040438e // ld1w { z14.s-z15.s }, p8/Z, [x28]\n" // Load bias + "addvl x28, x28, #2\n" + ".inst 0xc09025c0 // addha za0.s, p1/M, p1/M, z14.s\n" + ".inst 0xc09025e1 // addha za1.s, p1/M, p1/M, z15.s\n" + ".inst 0xc09025c2 // addha za2.s, p1/M, p1/M, z14.s\n" + ".inst 0xc09025e3 // addha za3.s, p1/M, p1/M, z15.s\n" + "ldr x20, [%x[args], %[offsetof_K]]\n" + "add x20, x20, #0x3\n" + "lsr x20, x20, #0x2\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 6f\n" + "subs x21, x21, #0x1\n" + ".inst 0xa0400762 // ld1b { z2.b-z3.b }, pn9.b/Z, [x27]\n" + ".inst 0xa1400780 // ld1b { z0.b, z8.b }, pn9.b/Z, [x28]\n" + ".inst 0xa0410772 // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0xa0410794 // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL]\n" + ".inst 0xa042077a // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa0420796 // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL]\n" + ".inst 0xa0430778 // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + ".inst 0xa0430784 // ld1b { z4.b-z5.b }, pn9.b/Z, [x28, #0x6, MUL VL]\n" + "addvl x28, x28, #8\n" + "ble 5f\n" + "4:" // K loop + ".inst 0xa0802440 // smopa za0.s, p1/M, p1/M, z2.b, z0.b\n" + "subs x21, x21, #0x1\n" + ".inst 0xa0882441 // smopa za1.s, p1/M, p1/M, z2.b, z8.b\n" + ".inst 0xa0802462 // smopa za2.s, p1/M, p1/M, z3.b, z0.b\n" + ".inst 0xa0882463 // smopa za3.s, p1/M, p1/M, z3.b, z8.b\n" + ".inst 0xa0400762 // ld1b { z2.b-z3.b }, pn9.b/Z, [x27]\n" + ".inst 0xa0942640 // smopa za0.s, p1/M, p1/M, z18.b, z20.b\n" + ".inst 0xa1400780 // ld1b { z0.b, z8.b }, pn9.b/Z, [x28]\n" + ".inst 0xa0952641 // smopa za1.s, p1/M, p1/M, z18.b, z21.b\n" + ".inst 0xa0942662 // smopa za2.s, p1/M, p1/M, z19.b, z20.b\n" + ".inst 0xa0952663 // smopa za3.s, p1/M, p1/M, z19.b, z21.b\n" + ".inst 0xa0410772 // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0xa0962740 // smopa za0.s, p1/M, p1/M, z26.b, z22.b\n" + ".inst 0xa0410794 // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL]\n" + ".inst 0xa0972741 // smopa za1.s, p1/M, p1/M, z26.b, z23.b\n" + ".inst 0xa0962762 // smopa za2.s, p1/M, p1/M, z27.b, z22.b\n" + ".inst 0xa0972763 // smopa za3.s, p1/M, p1/M, z27.b, z23.b\n" + ".inst 0xa042077a // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa0420796 // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL]\n" + ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n" + ".inst 0xa0852701 // smopa za1.s, p1/M, p1/M, z24.b, z5.b\n" + ".inst 0xa0842722 // smopa za2.s, p1/M, p1/M, z25.b, z4.b\n" + ".inst 0xa0852723 // smopa za3.s, p1/M, p1/M, z25.b, z5.b\n" + ".inst 0xa0430778 // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + ".inst 0xa0430784 // ld1b { z4.b-z5.b }, pn9.b/Z, [x28, #0x6, MUL VL]\n" + "addvl x28, x28, #8\n" + "bgt 4b\n" + "5:" // K loop tail + ".inst 0xa0802440 // smopa za0.s, p1/M, p1/M, z2.b, z0.b\n" + ".inst 0xa0882441 // smopa za1.s, p1/M, p1/M, z2.b, z8.b\n" + ".inst 0xa0802462 // smopa za2.s, p1/M, p1/M, z3.b, z0.b\n" + ".inst 0xa0882463 // smopa za3.s, p1/M, p1/M, z3.b, z8.b\n" + ".inst 0xa0942640 // smopa za0.s, p1/M, p1/M, z18.b, z20.b\n" + ".inst 0xa0952641 // smopa za1.s, p1/M, p1/M, z18.b, z21.b\n" + ".inst 0xa0942662 // smopa za2.s, p1/M, p1/M, z19.b, z20.b\n" + ".inst 0xa0952663 // smopa za3.s, p1/M, p1/M, z19.b, z21.b\n" + ".inst 0xa0962740 // smopa za0.s, p1/M, p1/M, z26.b, z22.b\n" + ".inst 0xa0972741 // smopa za1.s, p1/M, p1/M, z26.b, z23.b\n" + ".inst 0xa0962762 // smopa za2.s, p1/M, p1/M, z27.b, z22.b\n" + ".inst 0xa0972763 // smopa za3.s, p1/M, p1/M, z27.b, z23.b\n" + ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n" + ".inst 0xa0852701 // smopa za1.s, p1/M, p1/M, z24.b, z5.b\n" + ".inst 0xa0842722 // smopa za2.s, p1/M, p1/M, z25.b, z4.b\n" + ".inst 0xa0852723 // smopa za3.s, p1/M, p1/M, z25.b, z5.b\n" + "6:" // K oddments + "cbz x20, 8f\n" + "7:" // K oddments: Loop + ".inst 0xa0400770 // ld1b { z16.b-z17.b }, pn9.b/Z, [x27]\n" + "subs x20, x20, #0x1\n" + "addvl x27, x27, #2\n" + ".inst 0xa0400788 // ld1b { z8.b-z9.b }, pn9.b/Z, [x28]\n" + "addvl x28, x28, #2\n" + ".inst 0xa0882600 // smopa za0.s, p1/M, p1/M, z16.b, z8.b\n" + ".inst 0xa0892601 // smopa za1.s, p1/M, p1/M, z16.b, z9.b\n" + ".inst 0xa0882622 // smopa za2.s, p1/M, p1/M, z17.b, z8.b\n" + ".inst 0xa0892623 // smopa za3.s, p1/M, p1/M, z17.b, z9.b\n" + "bgt 7b\n" + "8:" // K oddments: End + "ldr x26, [%x[args], %[offsetof_C]]\n" + "sub x25, x14, x13\n" + "cntw x24\n" + "ld1rw { z27.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" + "ldr x23, [%x[args], %[offsetof_ldcb]]\n" + "whilelt p0.h, x11, x10\n" + "cmp x25, x24\n" + "ld1rw { z1.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" + "csel x22, x25, x24, LT\n" + "ld1rw { z0.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_result_zero_point]]\n" + "mov x12, #0x0\n" + "add x26, x26, x11\n" // C += n + "lsr x21, x22, #0x2\n" + "ld1w { z22.s }, p1/Z, [x28]\n" + "madd x26, x13, x23, x26\n" // C += m * ldc + "ld1w { z26.s }, p1/Z, [x28, #1, MUL VL]\n" + "and x20, x22, #0x3\n" + "addvl x28, x28, #2\n" + "cbz x21, 11f\n" + "10:" // Store to output array: Accumulator row 0 loop + ".inst 0xc0860410 // mova { z16.s-z19.s }, za0h.s[x12]\n" + ".inst 0xc086043c // mova { z28.s-z31.s }, za1h.s[x12]\n" + ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" + ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n" + "fmul z16.s, z16.s, z22.s\n" + "fmul z17.s, z17.s, z22.s\n" + "add x12, x12, #0x4\n" + "fmul z18.s, z18.s, z22.s\n" + "fmul z19.s, z19.s, z22.s\n" + "cmp x12, x21, LSL #2\n" + "fmul z28.s, z28.s, z26.s\n" + "fmul z29.s, z29.s, z26.s\n" + "fmul z30.s, z30.s, z26.s\n" + "fmul z31.s, z31.s, z26.s\n" + ".inst 0xc1b8e210 // frintn { z16.s-z19.s }, { z16.s-z19.s }\n" + ".inst 0xc131e210 // fcvtzs { z16.s-z19.s }, { z16.s-z19.s }\n" + ".inst 0xc1b8e39c // frintn { z28.s-z31.s }, { z28.s-z31.s }\n" + ".inst 0xc1a0ab10 // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s\n" + ".inst 0xc131e39c // fcvtzs { z28.s-z31.s }, { z28.s-z31.s }\n" + ".inst 0xc1a0ab1c // add { z28.s-z31.s }, { z28.s-z31.s }, z0.s\n" + ".inst 0xc1a1cf70 // sclamp { z16.s-z19.s }, z27.s, z1.s\n" + ".inst 0xc1a1cf7c // sclamp { z28.s-z31.s }, z27.s, z1.s\n" + "uzp1 z5.h, z16.h, z28.h\n" + "uzp1 z20.h, z17.h, z29.h\n" + "uzp1 z17.h, z18.h, z30.h\n" + "uzp1 z16.h, z19.h, z31.h\n" + "st1b { z5.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "st1b { z20.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "st1b { z17.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "blt 10b\n" + "11:" // Store to output array: Accumulator row 0 oddments + "cbz x20, 12f\n" + ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n" + ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n" + ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" + "fmul z4.s, z4.s, z22.s\n" + "fmul z5.s, z5.s, z22.s\n" + "subs x20, x20, #0x1\n" + "fmul z6.s, z6.s, z22.s\n" + "fmul z7.s, z7.s, z22.s\n" + "fmul z12.s, z12.s, z26.s\n" + "fmul z13.s, z13.s, z26.s\n" + "fmul z14.s, z14.s, z26.s\n" + "fmul z15.s, z15.s, z26.s\n" + ".inst 0xc1b8e084 // frintn { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc131e084 // fcvtzs { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc1b8e18c // frintn { z12.s-z15.s }, { z12.s-z15.s }\n" + ".inst 0xc1a0ab04 // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s\n" + ".inst 0xc131e18c // fcvtzs { z12.s-z15.s }, { z12.s-z15.s }\n" + ".inst 0xc1a0ab0c // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s\n" + ".inst 0xc1a1cf64 // sclamp { z4.s-z7.s }, z27.s, z1.s\n" + ".inst 0xc1a1cf6c // sclamp { z12.s-z15.s }, z27.s, z1.s\n" + "uzp1 z16.h, z4.h, z12.h\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "beq 12f\n" + "subs x20, x20, #0x1\n" + "uzp1 z16.h, z5.h, z13.h\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "beq 12f\n" + "uzp1 z16.h, z6.h, z14.h\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "12:" // Store to output array: Accumulator row 0 oddments: End + "subs x25, x25, x22\n" + "beq 16f\n" + "cmp x25, x24\n" + "mov x12, #0x0\n" + "csel x20, x25, x24, LT\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 14f\n" + "13:" // Store to output array: Accumulator row 1 loop + ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n" + ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n" + ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" + "fmul z8.s, z8.s, z22.s\n" + "fmul z9.s, z9.s, z22.s\n" + "add x12, x12, #0x4\n" + "fmul z10.s, z10.s, z22.s\n" + "fmul z11.s, z11.s, z22.s\n" + "cmp x12, x21, LSL #2\n" + "fmul z16.s, z16.s, z26.s\n" + "fmul z17.s, z17.s, z26.s\n" + "fmul z18.s, z18.s, z26.s\n" + "fmul z19.s, z19.s, z26.s\n" + ".inst 0xc1b8e108 // frintn { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc131e108 // fcvtzs { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc1b8e210 // frintn { z16.s-z19.s }, { z16.s-z19.s }\n" + ".inst 0xc1a0ab08 // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s\n" + ".inst 0xc131e210 // fcvtzs { z16.s-z19.s }, { z16.s-z19.s }\n" + ".inst 0xc1a0ab10 // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s\n" + ".inst 0xc1a1cf68 // sclamp { z8.s-z11.s }, z27.s, z1.s\n" + ".inst 0xc1a1cf70 // sclamp { z16.s-z19.s }, z27.s, z1.s\n" + "uzp1 z21.h, z8.h, z16.h\n" + "uzp1 z20.h, z9.h, z17.h\n" + "uzp1 z17.h, z10.h, z18.h\n" + "uzp1 z16.h, z11.h, z19.h\n" + "st1b { z21.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "st1b { z20.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "st1b { z17.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "blt 13b\n" + "14:" // Store to output array: Accumulator row 1 oddments + "cbz x20, 15f\n" + ".inst 0xc086044c // mova { z12.s-z15.s }, za2h.s[x12]\n" + ".inst 0xc0860464 // mova { z4.s-z7.s }, za3h.s[x12]\n" + ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" + ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" + "fmul z12.s, z12.s, z22.s\n" + "fmul z13.s, z13.s, z22.s\n" + "subs x20, x20, #0x1\n" + "fmul z14.s, z14.s, z22.s\n" + "fmul z15.s, z15.s, z22.s\n" + "fmul z4.s, z4.s, z26.s\n" + "fmul z5.s, z5.s, z26.s\n" + "fmul z6.s, z6.s, z26.s\n" + "fmul z7.s, z7.s, z26.s\n" + ".inst 0xc1b8e18c // frintn { z12.s-z15.s }, { z12.s-z15.s }\n" + ".inst 0xc131e18c // fcvtzs { z12.s-z15.s }, { z12.s-z15.s }\n" + ".inst 0xc1b8e084 // frintn { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc1a0ab0c // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s\n" + ".inst 0xc131e084 // fcvtzs { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc1a0ab04 // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s\n" + ".inst 0xc1a1cf6c // sclamp { z12.s-z15.s }, z27.s, z1.s\n" + ".inst 0xc1a1cf64 // sclamp { z4.s-z7.s }, z27.s, z1.s\n" + "uzp1 z16.h, z12.h, z4.h\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "beq 15f\n" + "subs x20, x20, #0x1\n" + "uzp1 z16.h, z13.h, z5.h\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "beq 15f\n" + "uzp1 z16.h, z14.h, z6.h\n" + "st1b { z16.h }, p0, [x26]\n" + "15:" // Store to output array: Accumulator row 1 oddments: End + "16:" // Store to output array: End + "incw x11, ALL, MUL #2\n" + "cmp x11, x10\n" + "blt 2b\n" + "incw x13, ALL, MUL #2\n" + "mov x11, #0x0\n" + "cmp x13, x14\n" + "mov x9, x27\n" + "blt 1b\n" + ".inst 0xd503467f // SMSTOP\n" + : + : [args] "r"(&args), [offsetof_A] "I"(offsetof(KernelArgs, A)), [offsetof_B] "I"(offsetof(KernelArgs, B)), + [offsetof_C] "I"(offsetof(KernelArgs, C)), [offsetof_K] "I"(offsetof(KernelArgs, K)), + [offsetof_KernelArgs_max] "I"(offsetof(KernelArgs, max)), + [offsetof_KernelArgs_min] "I"(offsetof(KernelArgs, min)), + [offsetof_KernelArgs_result_zero_point] "I"(offsetof(KernelArgs, result_zero_point)), + [offsetof_M] "I"(offsetof(KernelArgs, M)), [offsetof_N] "I"(offsetof(KernelArgs, N)), + [offsetof_ldcb] "I"(offsetof(KernelArgs, ldcb)) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", + "p15", "x9", "x10", "x11", "x12", "x13", "x14", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", + "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", + "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h new file mode 100644 index 0000000000000000000000000000000000000000..e8e2868b5d2ccc8b15f56404f34c2a94e6a729bd --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h @@ -0,0 +1,122 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "kai/kai_common.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_pack_x8p2vlx4_x8_sme to pack the LHS matrix. +/// -# kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme to pack the RHS matrix. + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return The m step value. +size_t kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void); + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void); + +/// Gets mr value. +/// +/// This is the packing parameter which must be used to pack the LHS matrix. +/// +/// @return The mr value. +size_t kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void); + +/// Gets nr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The nr value. +size_t kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void); + +/// Gets kr value. +/// +/// This is the packing parameter which must be used to pack the LHS and RHS matrix. +/// +/// @return The kr value. +size_t kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void); + +/// Gets sr value. +/// +/// This is the packing parameter which must be used to pack the LHS and RHS matrix. +/// +/// @return The sr value. +size_t kai_get_sr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void); + +/// Gets the offset in bytes to the data element in the packed LHS matrix buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(size_t m_idx, size_t k); + +/// Gets the offset in bytes to the data element in the packed RHS matrix buffer. +/// +/// @param[in] n_idx Column index in the unpacked RHS matrix. +/// @param[in] k Number of rows in the unpacked RHS matrix. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(size_t n_idx, size_t k); + +/// Gets the offset in bytes to the data element in the destination matrix buffer. +/// +/// @param[in] m_idx Row index. +/// @param[in] n_idx Column index. +/// @param[in] dst_stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_stride); + +/// Gets the size in bytes of the destination matrix buffer. +/// +/// @param[in] m Number of rows. +/// @param[in] n Number of columns. +/// +/// @return The size in bytes of the destination matrix buffer. +size_t kai_get_dst_size_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(size_t m, size_t n); + +/// Runs the matrix multiplication microkernel followed by a clamp operation. +/// +/// The pointer of each buffers (packed LHS, packed RHS and output) needs to be added with offset +/// calculated using the following functions: +/// +/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa. +/// +/// @param[in] m Number of output rows to be computed. +/// @param[in] n Number of output columns to be computed. +/// @param[in] k Common dimension of the LHS and RHS operands. +/// @param[in] packed_lhs Packed LHS matrix buffer. +/// @param[in] packed_rhs Packed RHS matrix buffer. +/// @param[out] dst Output matrix buffer. +/// @param[in] dst_stride_row Row stride in bytes of the output matrix. +/// @param[in] dst_stride_col Column stride in bytes of the output matrix. +/// @param[in] params Requantization and clamp parmaters. +void kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, + size_t dst_stride_col, const struct kai_matmul_requantize32_params* params); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_interface.h b/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..c09dc3e995fb7ed09db8c9f9c44777eb6cf6bc48 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_interface.h @@ -0,0 +1,52 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// All micro-kernels variants of the same type share the same interfaces +// In this case, the micro-kernel type is: matmul_clamp_qai8_qai8p_qsi8cxpsb + +/// Micro-kernel helper functions ("get" methods) +typedef size_t (*kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_m_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_n_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_mr_func_t)(void); +typedef size_t (*kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_nr_func_t)(void); +typedef size_t (*kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_kr_func_t)(void); +typedef size_t (*kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_sr_func_t)(void); +typedef size_t (*kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k); +typedef size_t (*kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); +typedef size_t (*kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_dst_offset_func_t)( + size_t m_idx, size_t n_idx, size_t dst_stride); +typedef size_t (*kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_dst_size_func_t)(size_t m, size_t n); + +/// Micro-kernel core function ("run" method) +typedef void (*kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_run_matmul_func_t)( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max); + +/// Micro-kernel interface +struct kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel { + kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_m_step_func_t get_m_step; + kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_n_step_func_t get_n_step; + kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_mr_func_t get_mr; + kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_nr_func_t get_nr; + kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_kr_func_t get_kr; + kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_sr_func_t get_sr; + kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_lhs_packed_offset_func_t get_lhs_packed_offset; + kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_rhs_packed_offset_func_t get_rhs_packed_offset; + kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_dst_offset_func_t get_dst_offset; + kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_get_dst_size_func_t get_dst_size; + kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_run_matmul_func_t run_matmul; +}; + +#ifdef __cplusplus +} +#endif diff --git a/kai/ukernels/matmul/pack/README.md b/kai/ukernels/matmul/pack/README.md index 5cb7133c3b5295591edd137545536e61b93d4bb9..950a69ac41ae0809c1a16e6b4fd1528b0f9f6da1 100644 --- a/kai/ukernels/matmul/pack/README.md +++ b/kai/ukernels/matmul/pack/README.md @@ -26,6 +26,20 @@ The pattern of the packed output is shown below Each block has bias and weights arranged as expected by the micro kernel to produce a mr x nr output matrix. There can be padding involved in the blocks depending on the combination of underlying instruction used for the optimization in the micro kernel, the chosen values of mr and nr and input dimensions, M, N and K. +#### kai_run_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() + +Pack RHS(weights), bias and scaling factor together into X number of blocks that are a combination of scale, bias and RHS. Details of the input are below. + +1. Values calculated using the bias, reduce_sum and lhs_zero point such that; Value\[n\] = Bias\[n\] - (lhs_zero_point * reduce_sum\[n\]). Each block has nr elements, including padding. +1. Non-transposed RHS of dimension KxN. Each block contains nr\*kr elements, including any padding. +1. Scale values calculated as Scale\[n\] = (rhs_scale\[n\] * lhs_scale) / dst_scale. Each block has nr elements, including any padding. + +The pattern of the packed output is shown below. + +![rhs_pack_pattern_2](../../../../docs/imgs/kai_rhs_packing_pattern_2.png)
+ +Padding may be involved in the blocks depending on the values of mr, nr and kr and the input dimensions, M, N and K. + ## Packing for int4 matmul micro-kernels For optimal cache utilization, the operands are packed for the matmul operations. There are 2 types of packing functions used in int4 matmul micro-kernels: diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c b/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c new file mode 100644 index 0000000000000000000000000000000000000000..59fb328f696abb9879a696f551d408b94d7021b2 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c @@ -0,0 +1,360 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include "kai_lhs_pack_x8p2vlx4_x8_sme.h" + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_mr = 2; +static const size_t kai_kr = 4; +static const size_t kai_sr = 1; + +static inline size_t kai_get_m_step(void) { + return (kai_mr * kai_get_sme_vector_length_u8()) / kai_kr; +} + +size_t kai_get_m_step_lhs_pack_x8p2vlx4_x8_sme(size_t mr) { + KAI_ASSUME(mr == kai_mr * kai_get_sme_vector_length_u8() / kai_kr); + KAI_UNUSED(mr); + + return (kai_mr * kai_get_sme_vector_length_u8()) / kai_kr; +} + +size_t kai_get_lhs_offset_lhs_pack_x8p2vlx4_x8_sme(size_t m_idx, size_t lhs_stride) { + KAI_ASSUME(m_idx % (kai_get_m_step()) == 0); + + return m_idx * lhs_stride; +} + +size_t kai_get_lhs_packed_offset_lhs_pack_x8p2vlx4_x8_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { + const size_t scaled_mr = kai_get_m_step(); + KAI_ASSUME(m_idx % scaled_mr == 0); + KAI_ASSUME(mr == scaled_mr); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + + KAI_UNUSED(mr); + KAI_UNUSED(kr); + KAI_UNUSED(sr); + + return m_idx * kai_roundup(k, kai_kr) * sizeof(int8_t); +} + +size_t kai_get_lhs_packed_size_lhs_pack_x8p2vlx4_x8_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { + KAI_ASSUME(mr == kai_get_m_step()); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + + KAI_UNUSED(mr); + KAI_UNUSED(kr); + KAI_UNUSED(sr); + + return (kai_roundup(m, kai_get_m_step()) * kai_roundup(k, kai_kr) * sizeof(int8_t)); +} + +void kai_run_lhs_pack_x8p2vlx4_x8_sme( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, + void* lhs_packed) { + KAI_ASSUME(mr == kai_get_m_step()); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + KAI_ASSUME(lhs != NULL); + KAI_ASSUME(lhs_packed != NULL); + + KAI_ASSUME(m_idx_start == 0); + + const size_t block_height = kai_get_m_step(); + const size_t width = k; + const size_t row_offset = 0; + + const void* in[block_height]; + const uint8_t* lhs_ptr = lhs; + uint8_t* lhs_packed_ptr = lhs_packed; + + for (size_t block_y = 0; block_y < m; block_y += block_height) { + const size_t height = KAI_MIN(m - block_y, block_height); + void* out = lhs_packed_ptr + block_y * kai_roundup(k, kai_kr) * sizeof(int8_t); + + for (size_t y = 0; y < height; y++) { + in[y] = lhs_ptr + (block_y + y) * lhs_stride; + } + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x23, %x[width]\n" + "mov x21, %x[width]\n" + "cntb x20\n" + "incb x23\n" + "sub x7, x20, #0x1\n" + "cntw x8\n" + "sub x23, x23, #0x1\n" + "ands x7, x21, x7\n" + "udiv x23, x23, x20\n" // n_passes = ceildiv(width, VL) + "csel x7, x7, x20, NE\n" + "lsl x22, %x[height], #0x1\n" // height * 2 + "lsl x21, x8, #0x1\n" + "sub x20, x23, #0x1\n" + "add x7, x7, #0x3\n" + "sub x17, x8, #0x2\n" + "whilelt p9.b, XZR, x22\n" + "whilelt p8.b, x21, x22\n" + "mov x16, #0x0\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + "cntw x9, ALL, MUL #2\n" + "cntw x28, ALL, MUL #3\n" + "ldr x27, [x11, #0x0]\n" + "lsr x20, x20, #0x1\n" // n_loops = (n_passes - 1) / 2 + "and x26, x23, #0x1\n" // odd_tail = bool(n_passes & 0x1) + "ldr x25, [x10, #0x0]\n" + "lsr x7, x7, #0x2\n" + "ptrue p11.s\n" + "ldr x24, [x11, #0x8]\n" + "zip1 p10.b, p9.b, p8.b\n" + "mov x23, %x[row_offset]\n" + "ldr x21, [x10, #0x8]\n" + "mov x22, %x[out]\n" + "whilelt p9.b, x16, %x[width]\n" + "whilelt p8.b, x16, %x[width]\n" + "add x11, x11, #0x10\n" + "add x10, x10, #0x10\n" + "mov x12, #0x0\n" + "cbz x17, 2f\n" + "1:" // K loop: Charge: Loop + ".inst 0x25246143 // psel p3.b, p8.b/Z, p10.b[w12]\n" + ".inst 0x252c6142 // psel p2.b, p8.b/Z, p10.b[w12, #1]\n" + ".inst 0x25646141 // psel p1.b, p8.b/Z, p10.b[w12, #4]\n" + ".inst 0x256c6140 // psel p0.b, p8.b/Z, p10.b[w12, #5]\n" + ".inst 0xe0170f60 // ld1b { za0h.b[x12] }, p3/Z, [x27, x23]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0170b21 // ld1b { za0h.b[x12, #1] }, p2/Z, [x25, x23]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0170704 // ld1b { za0h.b[x12, #4] }, p1/Z, [x24, x23]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe01702a5 // ld1b { za0h.b[x12, #5] }, p0/Z, [x21, x23]\n" + "add x12, x12, #0x8\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "cmp x12, x17, LSL #2\n" + "blt 1b\n" + "2:" // K loop: Charge: End + ".inst 0x25246143 // psel p3.b, p8.b/Z, p10.b[w12]\n" + ".inst 0x252c6142 // psel p2.b, p8.b/Z, p10.b[w12, #1]\n" + ".inst 0x25646141 // psel p1.b, p8.b/Z, p10.b[w12, #4]\n" + ".inst 0x256c6140 // psel p0.b, p8.b/Z, p10.b[w12, #5]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0170f60 // ld1b { za0h.b[x12] }, p3/Z, [x27, x23]\n" + "ldr x27, [x11, #0x0]\n" + "incb x16\n" + ".inst 0xe0170b21 // ld1b { za0h.b[x12, #1] }, p2/Z, [x25, x23]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0170704 // ld1b { za0h.b[x12, #4] }, p1/Z, [x24, x23]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe01702a5 // ld1b { za0h.b[x12, #5] }, p0/Z, [x21, x23]\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "incb x23\n" + "cbz x20, 8f\n" + "mov x20, x20\n" + "3:" // K loop: Main loop + "whilelt p8.b, x16, %x[width]\n" + "mov x15, #0x0\n" + "mov x14, #0x0\n" + "cbz x17, 5f\n" + "4:" // K loop: Main loop: First: Loop + ".inst 0x25376143 // psel p3.b, p8.b/Z, p10.b[w15, #2]\n" + ".inst 0x253f6142 // psel p2.b, p8.b/Z, p10.b[w15, #3]\n" + ".inst 0x25776141 // psel p1.b, p8.b/Z, p10.b[w15, #6]\n" + ".inst 0x257f6140 // psel p0.b, p8.b/Z, p10.b[w15, #7]\n" + ".inst 0xe0176f62 // ld1b { za0h.b[x15, #2] }, p3/Z, [x27, x23]\n" + ".inst 0x25266d23 // psel p3.b, p11.b/Z, p9.b[w14]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0176b23 // ld1b { za0h.b[x15, #3] }, p2/Z, [x25, x23]\n" + ".inst 0x25266d22 // psel p2.b, p11.b/Z, p9.b[w14]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0176706 // ld1b { za0h.b[x15, #6] }, p1/Z, [x24, x23]\n" + ".inst 0x252e6d21 // psel p1.b, p11.b/Z, p9.b[w14, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe01762a7 // ld1b { za0h.b[x15, #7] }, p0/Z, [x21, x23]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0x252e6d20 // psel p0.b, p11.b/Z, p9.b[w14, #1]\n" + "add x10, x10, #0x10\n" + ".inst 0xe0bfcec0 // st1w { za0v.s[x14] }, p3/Z, [x22, XZR, LSL #2]\n" + "add x15, x15, #0x8\n" + ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" + ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "add x14, x14, #0x2\n" + "addvl x22, x22, #4\n" + "cmp x14, x17\n" + "blt 4b\n" + "5:" // K loop: Main loop: First: Tail + ".inst 0x25376143 // psel p3.b, p8.b/Z, p10.b[w15, #2]\n" + ".inst 0x253f6142 // psel p2.b, p8.b/Z, p10.b[w15, #3]\n" + ".inst 0x25776141 // psel p1.b, p8.b/Z, p10.b[w15, #6]\n" + ".inst 0x257f6140 // psel p0.b, p8.b/Z, p10.b[w15, #7]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0176f62 // ld1b { za0h.b[x15, #2] }, p3/Z, [x27, x23]\n" + ".inst 0x25266d23 // psel p3.b, p11.b/Z, p9.b[w14]\n" + "ldr x27, [x11, #0x0]\n" + "mov x13, #0x0\n" + ".inst 0xe0176b23 // ld1b { za0h.b[x15, #3] }, p2/Z, [x25, x23]\n" + ".inst 0x25266d22 // psel p2.b, p11.b/Z, p9.b[w14]\n" + "ldr x25, [x10, #0x0]\n" + "mov x12, #0x0\n" + ".inst 0xe0176706 // ld1b { za0h.b[x15, #6] }, p1/Z, [x24, x23]\n" + ".inst 0x252e6d21 // psel p1.b, p11.b/Z, p9.b[w14, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe01762a7 // ld1b { za0h.b[x15, #7] }, p0/Z, [x21, x23]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0x252e6d20 // psel p0.b, p11.b/Z, p9.b[w14, #1]\n" + "whilelt p9.b, x16, %x[width]\n" + ".inst 0xe0bfcec0 // st1w { za0v.s[x14] }, p3/Z, [x22, XZR, LSL #2]\n" + "incb x16\n" + "add x10, x10, #0x10\n" + ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" + "incb x23\n" + "whilelt p8.b, x16, %x[width]\n" + ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "addvl x22, x22, #4\n" + "cbz x17, 7f\n" + "6:" // K loop: Main loop: Second: Loop + ".inst 0x25256143 // psel p3.b, p8.b/Z, p10.b[w13]\n" + ".inst 0x252d6142 // psel p2.b, p8.b/Z, p10.b[w13, #1]\n" + ".inst 0x25656141 // psel p1.b, p8.b/Z, p10.b[w13, #4]\n" + ".inst 0x256d6140 // psel p0.b, p8.b/Z, p10.b[w13, #5]\n" + ".inst 0xe0172f60 // ld1b { za0h.b[x13] }, p3/Z, [x27, x23]\n" + ".inst 0x25246d23 // psel p3.b, p11.b/Z, p9.b[w12]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0172b21 // ld1b { za0h.b[x13, #1] }, p2/Z, [x25, x23]\n" + ".inst 0x25246d22 // psel p2.b, p11.b/Z, p9.b[w12]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0172704 // ld1b { za0h.b[x13, #4] }, p1/Z, [x24, x23]\n" + ".inst 0x252c6d21 // psel p1.b, p11.b/Z, p9.b[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe01722a5 // ld1b { za0h.b[x13, #5] }, p0/Z, [x21, x23]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0x252c6d20 // psel p0.b, p11.b/Z, p9.b[w12, #1]\n" + "add x10, x10, #0x10\n" + ".inst 0xe0bf8ec8 // st1w { za2v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" + "add x13, x13, #0x8\n" + ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "add x12, x12, #0x2\n" + "addvl x22, x22, #4\n" + "cmp x12, x17\n" + "blt 6b\n" + "7:" // K loop: Main loop: Second: Tail + ".inst 0x25256143 // psel p3.b, p8.b/Z, p10.b[w13]\n" + ".inst 0x252d6142 // psel p2.b, p8.b/Z, p10.b[w13, #1]\n" + ".inst 0x25656141 // psel p1.b, p8.b/Z, p10.b[w13, #4]\n" + ".inst 0x256d6140 // psel p0.b, p8.b/Z, p10.b[w13, #5]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0172f60 // ld1b { za0h.b[x13] }, p3/Z, [x27, x23]\n" + ".inst 0x25246d23 // psel p3.b, p11.b/Z, p9.b[w12]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0172b21 // ld1b { za0h.b[x13, #1] }, p2/Z, [x25, x23]\n" + ".inst 0x25246d22 // psel p2.b, p11.b/Z, p9.b[w12]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0172704 // ld1b { za0h.b[x13, #4] }, p1/Z, [x24, x23]\n" + ".inst 0x252c6d21 // psel p1.b, p11.b/Z, p9.b[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe01722a5 // ld1b { za0h.b[x13, #5] }, p0/Z, [x21, x23]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0x252c6d20 // psel p0.b, p11.b/Z, p9.b[w12, #1]\n" + "whilelt p9.b, x16, %x[width]\n" + ".inst 0xe0bf8ec8 // st1w { za2v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" + "subs x20, x20, #0x1\n" + "add x10, x10, #0x10\n" + ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + "incb x16\n" + "incb x23\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "addvl x22, x22, #4\n" + "bgt 3b\n" + "8:" // K loop: Tails + "cbnz x26, 11f\n" + "mov x11, %x[in]\n" + "whilelt p8.b, x16, %x[width]\n" + "mov x13, #0x0\n" + "mov x12, #0x0\n" + "9:" // K loop: Tails: Even: First + ".inst 0x25306d23 // psel p3.s, p11.s/Z, p9.s[w12]\n" + ".inst 0x25306d22 // psel p2.s, p11.s/Z, p9.s[w12]\n" + ".inst 0x25356141 // psel p1.b, p8.b/Z, p10.b[w13, #2]\n" + ".inst 0x253d6140 // psel p0.b, p8.b/Z, p10.b[w13, #3]\n" + ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a88ac4 // st1w { za1v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "ldr x21, [x11, #0x0]\n" + "cmp x12, x8\n" + "ldr x20, [x11, x8, LSL #0x3]\n" + "add x11, x11, #0x8\n" + ".inst 0xe01726a2 // ld1b { za0h.b[x13, #2] }, p1/Z, [x21, x23]\n" + ".inst 0xe0172283 // ld1b { za0h.b[x13, #3] }, p0/Z, [x20, x23]\n" + "add x13, x13, #0x4\n" + "blt 9b\n" + "whilelt p9.b, x16, %x[width]\n" + "whilelt p8.b, x16, %x[width]\n" + "mov x20, #0x0\n" + "mov x12, #0x0\n" + "10:" // K loop: Tails: Even: Second + ".inst 0x25306d21 // psel p1.s, p11.s/Z, p9.s[w12]\n" + ".inst 0x25306d20 // psel p0.s, p11.s/Z, p9.s[w12]\n" + "add x20, x20, #0x4\n" + ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a882cc // st1w { za3v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x7\n" + "blt 10b\n" + "whilelt p8.b, x16, %x[width]\n" + "b 13f\n" + "11:" // K loop: Tails: Odd + "mov x12, #0x0\n" + "12:" // K loop: Tails: Odd: Loop + ".inst 0x25306d21 // psel p1.s, p11.s/Z, p9.s[w12]\n" + ".inst 0x25306d20 // psel p0.s, p11.s/Z, p9.s[w12]\n" + ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a882c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x7\n" + "blt 12b\n" + "13:" // K loop: End + "mov %x[out], x22\n" + ".inst 0xd503467f // SMSTOP\n" + : [out] "+&r"(out) + : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", + "p14", "p15", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", + "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", + "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", + "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); + } +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.h b/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.h new file mode 100644 index 0000000000000000000000000000000000000000..3b95a0a9aa76e64e1f4c9bd6b485b1241dce7e83 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.h @@ -0,0 +1,77 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @param[in] mr Number of rows to be interleaved. +/// +/// @return The m step value. +size_t kai_get_m_step_lhs_pack_x8p2vlx4_x8_sme(size_t mr); + +/// Gets the offset in bytes to the data element in the LHS buffer. +/// +/// @param[in] m_idx Row index. +/// @param[in] lhs_stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_offset_lhs_pack_x8p2vlx4_x8_sme(size_t m_idx, size_t lhs_stride); + +/// Gets the offset in bytes to the data element in the packed LHS buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// @param[in] mr Number of rows to be interleaved. +/// @param[in] kr Unused. Must be 1. +/// @param[in] sr Unused. Must be 1. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_lhs_pack_x8p2vlx4_x8_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr); + +/// Gets the size in bytes of the packed LHS buffer. +/// +/// @param[in] m Number of rows in the unpacked LHS matrix. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// @param[in] mr Number of rows to be interleaved. +/// @param[in] kr Unused. Must be 1. +/// @param[in] sr Unused. Must be 1. +/// +/// @return The size in bytes of the packed LHS buffer. +size_t kai_get_lhs_packed_size_lhs_pack_x8p2vlx4_x8_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr); + +/// Runs the LHS packing function for matrix multiplication. +/// +/// The pointer of each buffers (LHS and packed LHS) needs to be added with offset +/// calculated using the following functions: +/// +/// * LHS: @ref kai_get_lhs_offset_lhs_pack_x8p2vlx4_x8_sme. +/// * Packed LHS: @ref kai_get_lhs_packed_offset_lhs_pack_x8p2vlx4_x8_sme. +/// +/// @param[in] m Number of rows of the unpacked LHS matrix. +/// @param[in] k Common dimension between the LHS and RHS matrix. +/// @param[in] mr Block size in M dimension. It must be 2 * kai_get_sme_vector_length_u8(). +/// @param[in] kr Block size in K dimension. It must be 4. +/// @param[in] sr Number of kr splits. It must be 1. +/// @param[in] m_idx_start Unused. Must be 0. +/// @param[in] lhs LHS matrix data buffer. +/// @param[in] lhs_stride Row stride in bytes of the LHS matrix. +/// @param[out] lhs_packed Packed RHS matrix. +void kai_run_lhs_pack_x8p2vlx4_x8_sme( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, + void* lhs_packed); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c new file mode 100644 index 0000000000000000000000000000000000000000..c50ea6250bc5bf83759e08e569f504ff3eecdcaf --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c @@ -0,0 +1,262 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include "kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" + +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_nr = 2; +static const size_t kai_kr = 4; +static const size_t kai_num_bytes_input = 1; +static const size_t kai_num_bytes_output = 1; +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_num_bytes_scale = 4; + +size_t kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(void) { + return kai_nr * kai_get_sme_vector_length_u8() / kai_kr; +} + +size_t kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) { + KAI_ASSUME(n_idx % (kai_nr * kai_get_sme_vector_length_u8() / kai_kr) == 0); + + return n_idx * kai_num_bytes_input; +} + +size_t kai_get_bias_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) { + return n_idx * kai_num_bytes_bias; +} + +size_t kai_get_scale_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) { + return n_idx * kai_num_bytes_scale; +} + +size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t k) { + return kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() * + (kai_num_bytes_bias + kai_roundup(k, kai_kr) * kai_num_bytes_output + kai_num_bytes_scale); +} + +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() == 0); + + return (n_idx / kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme()) * + kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(k); +} + +size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n, size_t k) { + return kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + kai_roundup(n, kai_nr * kai_get_sme_vector_length_u8() / kai_kr), k); +} + +void kai_run_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_qsi8_params* params) { + KAI_ASSUME(num_groups == 1); + KAI_ASSUME(nr == kai_nr * kai_get_sme_vector_length_u8() / kai_kr); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == 1); + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(bias != NULL); + KAI_ASSUME(rhs_packed != NULL); + KAI_ASSUME(extra_bytes == 0); + KAI_ASSUME(params != NULL); + + size_t height = k; + const size_t width = n; + const void* in = rhs; + void* out = rhs_packed; + const size_t in_stride = rhs_stride; + uint8_t pad_row[nr]; + + if (height % 4) { + memset(pad_row, 0, nr * sizeof(uint8_t)); + } + + size_t out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(height); + const int32_t lhs_zero_point = params->lhs_zero_point; + const float scale_multiplier = params->scale_multiplier; + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "cmp %x[height], #0x8\n" + "mov x11, %x[out]\n" + "ptrue p2.b\n" + "mov x10, %x[height]\n" + "incb %x[out], ALL, MUL #2\n" + "blt 4f\n" + "1:" // Main row loop: Head + "mov x9, %x[in]\n" + "mov x28, %x[out]\n" + "add x27, x9, %x[in_stride]\n" + "sub %x[height], %x[height], #0x8\n" + "add x26, x27, %x[in_stride]\n" + "mov x24, %x[width]\n" + "add x25, x26, %x[in_stride]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x21, x22, %x[in_stride]\n" + "add x20, x21, %x[in_stride]\n" + "add %x[in], x20, %x[in_stride]\n" + "2:" // Main row loop: Column loop + "whilelt p0.b, XZR, x24\n" + "decw x24, ALL, MUL #2\n" + "ld1b { z18.b }, p0/Z, [x9]\n" + "cmp x24, #0x0\n" + "incd x9, ALL, MUL #4\n" + "ld1b { z22.b }, p0/Z, [x27]\n" + "incd x27, ALL, MUL #4\n" + "ld1b { z17.b }, p0/Z, [x26]\n" + "incd x26, ALL, MUL #4\n" + "ld1b { z16.b }, p0/Z, [x25]\n" + "incd x25, ALL, MUL #4\n" + "ld1b { z20.b }, p0/Z, [x23]\n" + "incd x23, ALL, MUL #4\n" + "ld1b { z19.b }, p0/Z, [x22]\n" + "zip1 z21.b, z18.b, z17.b\n" + "incd x22, ALL, MUL #4\n" + "ld1b { z18.b }, p0/Z, [x21]\n" + "zip1 z17.b, z22.b, z16.b\n" + "incd x21, ALL, MUL #4\n" + "ld1b { z16.b }, p0/Z, [x20]\n" + "incd x20, ALL, MUL #4\n" + "zip1 z20.b, z20.b, z18.b\n" + "zip1 z16.b, z19.b, z16.b\n" + "zip1 z19.b, z21.b, z17.b\n" + "zip2 z18.b, z21.b, z17.b\n" + "zip1 z17.b, z20.b, z16.b\n" + "zip2 z16.b, z20.b, z16.b\n" + "st1b { z19.b }, p2, [x28]\n" + "st1b { z18.b }, p2, [x28, #1, MUL VL]\n" + "st1b { z17.b }, p2, [x28, #2, MUL VL]\n" + "st1b { z16.b }, p2, [x28, #3, MUL VL]\n" + "add x28, x28, %x[out_stride]\n" + "bgt 2b\n" + "cmp %x[height], #0x8\n" + "addvl %x[out], %x[out], #4\n" + "bge 1b\n" + "cbz %x[height], 8f\n" + "4:" // Main loop skip + "5:" // Tail row loop: Head + "mov x9, %x[in]\n" + "cntw x24, ALL, MUL #2\n" + "add x27, x9, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add x26, x27, %x[in_stride]\n" + "csel x23, x24, XZR, GT\n" + "add x25, x26, %x[in_stride]\n" + "csel x26, x26, %x[pad_row], GE\n" + "add %x[in], x25, %x[in_stride]\n" + "csel x25, x25, %x[pad_row], GT\n" + "csel x22, x24, XZR, GE\n" + "cmp %x[height], #0x1\n" + "mov x28, %x[out]\n" + "csel x27, x27, %x[pad_row], GT\n" + "csel x21, x24, XZR, GT\n" + "sub %x[height], %x[height], #0x4\n" + "mov x20, %x[width]\n" + "6:" // Tail row loop: Column loop + "whilelt p0.b, XZR, x20\n" + "decw x20, ALL, MUL #2\n" + "ld1b { z18.b }, p0/Z, [x9]\n" + "cmp x20, #0x0\n" + "add x9, x9, x24\n" + "ld1b { z19.b }, p0/Z, [x27]\n" + "add x27, x27, x21\n" + "ld1b { z17.b }, p0/Z, [x26]\n" + "add x26, x26, x22\n" + "ld1b { z16.b }, p0/Z, [x25]\n" + "add x25, x25, x23\n" + "zip1 z18.b, z18.b, z17.b\n" + "zip1 z16.b, z19.b, z16.b\n" + "zip1 z17.b, z18.b, z16.b\n" + "zip2 z16.b, z18.b, z16.b\n" + "st1b { z17.b }, p2, [x28]\n" + "st1b { z16.b }, p2, [x28, #1, MUL VL]\n" + "add x28, x28, %x[out_stride]\n" + "bgt 6b\n" + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #2\n" + "bge 5b\n" + "8:" // Done + "mov x22, %x[out]\n" + "mov x21, %x[width]\n" + "dup z18.s, %w[scale_multiplier]\n" + "cbz %x[scale], 10f\n" + "9:" // Scale: Full loop + "mov x20, x21\n" + "decw x21, ALL, MUL #2\n" + "whilelt p1.s, XZR, x20\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z17.s }, p1/Z, [%x[scale]]\n" + "cmp x21, #0x0\n" + "ld1w { z16.s }, p0/Z, [%x[scale], #1, MUL VL]\n" + "incb %x[scale], ALL, MUL #2\n" + "fmul z17.s, z17.s, z18.s\n" + "fmul z16.s, z16.s, z18.s\n" + "st1w { z17.s }, p2, [x22]\n" + "st1w { z16.s }, p2, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 9b\n" + "10:" // Scale: Done + "cbz %x[width], 13f\n" + "cbz x10, 13f\n" + "dup z21.s, %w[lhs_zero_point]\n" + "add x25, x10, #0x3\n" + "cntw x24, ALL, MUL #2\n" + "mov z20.b, #0x1\n" + "lsr x25, x25, #0x2\n" + "mov x23, %x[width]\n" + "addvl x22, x11, #2\n" + "neg z21.s, p2/M, z21.s\n" + "11:" // Bias: N loop + "mov x21, x22\n" + "mov x20, x25\n" + "mov z19.s, #0x0\n" + "mov z18.s, #0x0\n" + "12:" // Bias: K loop + "ld1b { z17.b }, p2/Z, [x21]\n" + "subs x20, x20, #0x1\n" + "ld1b { z16.b }, p2/Z, [x21, #1, MUL VL]\n" + "addvl x21, x21, #2\n" + "sdot z19.s, z17.b, z20.b\n" + "sdot z18.s, z16.b, z20.b\n" + "bgt 12b\n" + "mov x20, x23\n" + "add x22, x22, %x[out_stride]\n" + "whilelt p1.s, XZR, x20\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z17.s }, p1/Z, [%x[bias]]\n" + "subs x23, x23, x24\n" + "ld1w { z16.s }, p0/Z, [%x[bias], #1, MUL VL]\n" + "addvl %x[bias], %x[bias], #2\n" + "mla z17.s, p2/M, z19.s, z21.s\n" + "mla z16.s, p2/M, z18.s, z21.s\n" + "st1w { z17.s }, p2, [x11]\n" + "st1w { z16.s }, p2, [x11, #1, MUL VL]\n" + "add x11, x11, %x[out_stride]\n" + "bgt 11b\n" + "13:" // Bias: Done + ".inst 0xd503467f // SMSTOP\n" + : [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out), [scale] "+&r"(scale) + : [in_stride] "r"(in_stride), [lhs_zero_point] "r"(lhs_zero_point), [out_stride] "r"(out_stride), + [pad_row] "r"(pad_row), [scale_multiplier] "r"(scale_multiplier), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", + "p15", "x9", "x10", "x11", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", + "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", + "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h new file mode 100644 index 0000000000000000000000000000000000000000..effb8da9b9926b582f8e7a864cd195d8714dd5ac --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h @@ -0,0 +1,91 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "kai/kai_common.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets n step value. +/// +/// The starting row index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(void); + +/// Gets the offset in bytes to the data element in the RHS matrix buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the bias buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_bias_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the scale buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_scale_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the packed RHS buffer. +/// +/// @param[in] n_idx Column index. +/// @param[in] k Number of rows. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx, size_t k); + +/// Gets the size in bytes of the packed RHS buffer. +/// +/// @param[in] n Number of columns. +/// @param[in] k Number of rows. +/// +/// @return The size in bytes of the packed RHS buffer. +size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n, size_t k); + +/// Runs the RHS packing function for matrix multiplication. +/// +/// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset +/// calculated using the following functions: +/// +/// * RHS: @ref kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(. +/// * Bias: @ref kai_get_bias_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(. +/// * Scale: @ref kai_get_scale_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(. +/// * Output: @ref kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(. +/// +/// @param[in] num_groups Number of groups. It must be 1. +/// @param[in] n Number of columns of the output matrix. +/// @param[in] k Common dimension between the LHS and RHS matrix. +/// @param[in] nr Block size in N dimension. It must be 2 * kai_get_sme_vector_length_u8(). +/// @param[in] kr Block size in K dimension. It must be 4. +/// @param[in] sr Number of kr splits. It must be 1. +/// @param[in] rhs_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs RHS matrix data buffer. +/// @param[in] bias Bias matrix data buffer. +/// @param[in] scale Scale data buffer. It must be NULL. +/// @param[out] rhs_packed Packed RHS matrix. +/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. +/// @param[in] params Extra packing parameters. It must be NULL. +void kai_run_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_qsi8_params* params); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/test/reference/binary_elementwise.cpp b/test/reference/binary_elementwise.cpp index d5abd26cc18422ee78334b0a63d3f9f9a907c076..f65136572c6689a919eea05359e3e539a4f9a2ed 100644 --- a/test/reference/binary_elementwise.cpp +++ b/test/reference/binary_elementwise.cpp @@ -136,6 +136,18 @@ std::vector sub( lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width); } +template +std::vector sub( + const void* lhs, size_t lhs_height, size_t lhs_width, // + const void* rhs, size_t rhs_height, size_t rhs_width) { + return binary_elementwise_any_op_type( + lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width); +} + +template std::vector sub( + const void* lhs, size_t lhs_height, size_t lhs_width, // + const void* rhs, size_t rhs_height, size_t rhs_width); + std::vector mul( const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) { @@ -143,6 +155,22 @@ std::vector mul( lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width); } +template +std::vector mul( + const void* lhs, size_t lhs_height, size_t lhs_width, // + const void* rhs, size_t rhs_height, size_t rhs_width) { + return binary_elementwise_any_op_type( + lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width); +} + +template std::vector mul( + const void* lhs, size_t lhs_height, size_t lhs_width, // + const void* rhs, size_t rhs_height, size_t rhs_width); + +template std::vector mul( + const void* lhs, size_t lhs_height, size_t lhs_width, // + const void* rhs, size_t rhs_height, size_t rhs_width); + std::vector div( const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) { diff --git a/test/reference/binary_elementwise.hpp b/test/reference/binary_elementwise.hpp index e3d3a9e1fdc36c7846859ebced2e2a09ae8938a5..f2f5c0ab5522748433ecee11a2a4b6552fc834df 100644 --- a/test/reference/binary_elementwise.hpp +++ b/test/reference/binary_elementwise.hpp @@ -50,6 +50,25 @@ std::vector sub( const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width); +/// Elementwise subtraction. +/// +/// Broadcasting is supported for any dimension and both LHS and RHS operands. +/// +/// @tparam T The data type. +/// +/// @param[in] lhs The LHS data buffer. +/// @param[in] lhs_height The number of rows of the LHS matrix. +/// @param[in] lhs_width The number of columns of the LHS matrix. +/// @param[in] rhs The RHS data buffer. +/// @param[in] rhs_height The number of rows of the RHS matrix. +/// @param[in] rhs_width The number of columns of the LHS matrix. +/// +/// @return The result matrix. +template +std::vector sub( + const void* lhs, size_t lhs_height, size_t lhs_width, // + const void* rhs, size_t rhs_height, size_t rhs_width); + /// Elementwise multiplication. /// /// Broadcasting is supported for any dimension and both LHS and RHS operands. @@ -68,6 +87,25 @@ std::vector mul( const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, // const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width); +/// Elementwise multiplication. +/// +/// Broadcasting is supported for any dimension and both LHS and RHS operands. +/// +/// @tparam T The data type. +/// +/// @param[in] lhs The LHS data buffer. +/// @param[in] lhs_height The number of rows of the LHS matrix. +/// @param[in] lhs_width The number of columns of the LHS matrix. +/// @param[in] rhs The RHS data buffer. +/// @param[in] rhs_height The number of rows of the RHS matrix. +/// @param[in] rhs_width The number of columns of the LHS matrix. +/// +/// @return The result matrix. +template +std::vector mul( + const void* lhs, size_t lhs_height, size_t lhs_width, // + const void* rhs, size_t rhs_height, size_t rhs_width); + /// Elementwise division. /// /// Broadcasting is supported for any dimension and both LHS and RHS operands. diff --git a/test/reference/clamp.cpp b/test/reference/clamp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a4ba773cc432ce016c554b3f670c30c52bf2b52f --- /dev/null +++ b/test/reference/clamp.cpp @@ -0,0 +1,32 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/clamp.hpp" + +#include +#include +#include +#include + +#include "test/common/memory.hpp" +#include "test/common/round.hpp" + +namespace kai::test { + +template +std::vector clamp(const void* src, size_t len, T min_value, T max_value) { + std::vector dst(round_up_division(len * size_in_bits, 8)); + + for (size_t i = 0; i < len; ++i) { + write_array(dst.data(), i, std::clamp(read_array(src, i), min_value, max_value)); + } + + return dst; +} + +template std::vector clamp(const void* src, size_t len, float min_value, float max_value); + +} // namespace kai::test diff --git a/test/reference/clamp.hpp b/test/reference/clamp.hpp new file mode 100644 index 0000000000000000000000000000000000000000..24d3ac6c74c0acb02c4742ccdeb307e92f7f06bd --- /dev/null +++ b/test/reference/clamp.hpp @@ -0,0 +1,24 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace kai::test { + +/// Clamps the matrix. +/// +/// @param[in] src Data buffer of the source matrix. +/// @param[in] len Number of values in the source matrix. +/// @param[in] min_value Lower bound of clamp. +/// @param[in] width Upper bound of clamp. +template +std::vector clamp(const void* src, size_t len, T min_value, T max_value); + +} // namespace kai::test diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp index 6ac8fe71233e0c09e6f9c82b475a5f8cefba2694..a1be25566ab745762e47321ce223da0ab3440a9c 100644 --- a/test/reference/matmul.cpp +++ b/test/reference/matmul.cpp @@ -185,6 +185,79 @@ std::vector matmul( return tmp_dst; } +template < + typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, + typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> +std::vector matmul_nt_t_quantized( + size_t m, size_t n, size_t k, // + const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, + size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, + size_t rhs_quant_width, // + const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width) { + const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width); + const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width); + + std::vector dst(m * n * sizeof(DstData)); + + for (size_t y = 0; y < m; ++y) { + for (size_t x = 0; x < n; ++x) { + DstData acc = 0; + + for (size_t i = 0; i < k; ++i) { + const auto lhs_data_index = y * k + i; + const auto lhs_quant_index = y / lhs_quant_height * lhs_num_quant_per_row + i / lhs_quant_width; + const auto lhs_value = read_array(lhs_data, lhs_data_index); + const auto lhs_scale = lhs_scales != nullptr ? read_array(lhs_scales, lhs_quant_index) + : static_cast(1); + const auto lhs_zero_point = lhs_zero_points != nullptr + ? read_array(lhs_zero_points, lhs_quant_index) + : static_cast(0); + + const auto rhs_data_index = x * k + i; + const auto rhs_quant_index = x / rhs_quant_height * rhs_num_quant_per_row + i / rhs_quant_width; + const auto rhs_value = read_array(rhs_data, rhs_data_index); + const auto rhs_scale = rhs_scales != nullptr ? read_array(rhs_scales, rhs_quant_index) + : static_cast(1); + const auto rhs_zero_point = rhs_zero_points != nullptr + ? read_array(rhs_zero_points, rhs_quant_index) + : static_cast(0); + + acc += (static_cast(lhs_value) - static_cast(lhs_zero_point)) * + static_cast(lhs_scale) * + (static_cast(rhs_value) - static_cast(rhs_zero_point)) * + static_cast(rhs_scale); + } + + if (bias_data != nullptr) { + const auto bias_value = read_array(bias_data, x); + const auto bias_scale = bias_scales != nullptr + ? read_array(bias_scales, x / bias_quant_width) + : static_cast(1); + const auto bias_zero_point = bias_zero_points != nullptr + ? read_array(bias_zero_points, x / bias_quant_width) + : static_cast(0); + + acc += (static_cast(bias_value) - static_cast(bias_zero_point)) * + static_cast(bias_scale); + } + + write_array(dst.data(), y * n + x, acc); + } + } + + return dst; +} + +template std::vector +matmul_nt_t_quantized( + size_t m, size_t n, size_t k, // + const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, + size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, + size_t rhs_quant_width, // + const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); + template < typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData> @@ -224,8 +297,8 @@ std::vector matmul_clamp_nt_t( : 0; acc += static_cast( - (static_cast(lhs_value) + static_cast(lhs_zero_point)) * - (static_cast(rhs_value) + static_cast(rhs_zero_point))) * + (static_cast(lhs_value) - static_cast(lhs_zero_point)) * + (static_cast(rhs_value) - static_cast(rhs_zero_point))) * static_cast(lhs_scale) * static_cast(rhs_scale); } @@ -309,8 +382,8 @@ std::vector matmul_clamp_nt_nt( : 0; acc += static_cast( - (static_cast(lhs_value) + static_cast(lhs_zero_point)) * - (static_cast(rhs_value) + static_cast(rhs_zero_point))) * + (static_cast(lhs_value) - static_cast(lhs_zero_point)) * + (static_cast(rhs_value) - static_cast(rhs_zero_point))) * static_cast(lhs_scale) * static_cast(rhs_scale); } diff --git a/test/reference/matmul.hpp b/test/reference/matmul.hpp index 88a0729ff662f67b25321034a6981818c13edfe7..9a8ce9f809558a77f30c66912ba3dc6206fb596a 100644 --- a/test/reference/matmul.hpp +++ b/test/reference/matmul.hpp @@ -145,4 +145,15 @@ std::vector matmul_clamp_nt_nt( const void* biases, // DstData min_value, DstData max_value); +template < + typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, + typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> +std::vector matmul_nt_t_quantized( + size_t m, size_t n, size_t k, // + const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, + size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, + size_t rhs_quant_width, // + const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); + } // namespace kai::test diff --git a/test/reference/matmul_pack.cpp b/test/reference/matmul_pack.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f04de8dc632e8a32bfc090bcd202879216da7248 --- /dev/null +++ b/test/reference/matmul_pack.cpp @@ -0,0 +1,54 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/matmul_pack.hpp" + +#include +#include +#include +#include + +#include "test/common/round.hpp" +#include "test/reference/binary_elementwise.hpp" +#include "test/reference/pack.hpp" +#include "test/reference/reduce.hpp" +#include "test/reference/reorder.hpp" + +namespace kai::test { + +template +std::vector matmul_pack_rhs_nxk_static_quantized( + const void* data, const void* scales, Scale lhs_scale, Scale dst_scale, const void* biases, + ZeroPoint lhs_zero_point, size_t n, size_t k, size_t block_height, size_t block_width) { + // The RHS data matrix is reordered according to the blocking parameters. + const auto reordered_data = reorder_block(data, n, k, block_height, block_width); + + // The effective per-channel scale: + // final_scales[n_index] = lhs_scale * rhs_scales[n_index] / dst_scale. + const auto scale_multiplier = lhs_scale / dst_scale; + auto combined_scales = mul(scales, 1, n, &scale_multiplier, 1, 1); + combined_scales.resize(round_up_multiple(n, block_height) * sizeof(Scale)); // Pads with 0s. + + // The effective per-channel biases: + // final_biases[n_index] = biases[n_index] - lhs_zero_point * sum(data[n_index, :]). + const auto row_sum = reduce_add_x(data, n, k); + const auto row_sum_times_lhs_zp = mul(row_sum.data(), n, k, &lhs_zero_point, 1, 1); + auto combined_biases = sub(biases, 1, n, row_sum_times_lhs_zp.data(), 1, n); + combined_biases.resize(round_up_multiple(n, block_height) * sizeof(ZeroPoint)); // Pads with 0s. + + // Packs the effective biases followed by the data block followed by the effective scales for the block. + auto packed_rhs = pack_zero_points_data_scales_per_block( + combined_biases.data(), reordered_data.data(), combined_scales.data(), round_up_division(n, block_height), + block_height, block_height * round_up_multiple(k, block_width), block_height); + + return packed_rhs; +} + +template std::vector matmul_pack_rhs_nxk_static_quantized( + const void* data, const void* scales, float lhs_scale, float dst_scale, const void* biases, int32_t lhs_zero_point, + size_t n, size_t k, size_t block_height, size_t block_width); + +} // namespace kai::test diff --git a/test/reference/matmul_pack.hpp b/test/reference/matmul_pack.hpp new file mode 100644 index 0000000000000000000000000000000000000000..30646c9875d8be9ef80ad47362a203ed4c68881c --- /dev/null +++ b/test/reference/matmul_pack.hpp @@ -0,0 +1,44 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace kai::test { + +/// Packs the RHS buffer for static quantized GeMM. +/// +/// The RHS matrix must be transposed. +/// +/// This function can be used when the following conditions are met: +/// * LHS, RHS and DST data types have the same size and are quantized. +/// * LHS is asymmetric per-tensor, RHS is symmetric per-channel and DST is asymmetric per-tensor. +/// +/// @tparam Data The data type of the RHS matrix. +/// @tparam Scale The data type of the quantization scales. +/// @tparam ZeroPoint The data type of the quantization zero points and the operator biases. +/// +/// @param[in] data The data buffer of the RHS matrix. +/// @param[in] scales The quantization scales of the RHS matrix. +/// @param[in] lhs_scale The quantization scale of the LHS matrix. +/// @param[in] dst_scale The quantization scale of the DST matrix. +/// @param[in] biases The biases of the operator. +/// @param[in] lhs_zero_point The quantization zero point of the LHS matrix. +/// @param[in] n The number of columns of the non-transposed RHS matrix. +/// @param[in] k The number of rows of the non-transposed RHS matrix. +/// @param[in] block_height The number of rows of a data block (N dimension). +/// @param[in] block_width The number of columns of a data block (K dimension). +/// +/// @return The packed RHS. +template +std::vector matmul_pack_rhs_nxk_static_quantized( + const void* data, const void* scales, Scale lhs_scale, Scale dst_scale, const void* biases, + ZeroPoint lhs_zero_point, size_t n, size_t k, size_t block_height, size_t block_width); + +} // namespace kai::test diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp index 25c77a6382dcc02c8fa2ae62efe2022943e73288..ed8dd09deead44f2e7e90ea2ad84a59012c1b3d2 100644 --- a/test/reference/pack.cpp +++ b/test/reference/pack.cpp @@ -255,6 +255,56 @@ std::vector pack_data_scales( return dst; } +template +std::vector pack_zero_points_data_scales_per_block( + const void* zero_points, const void* data, const void* scales, size_t num_blocks, size_t block_num_zero_points, + size_t block_num_data, size_t block_num_scales) { + // Only data is allowed to be sub-byte. + KAI_ASSUME(size_in_bits % 8 == 0); + KAI_ASSUME(size_in_bits % 8 == 0); + + // Checks for memory alignment. + KAI_ASSUME(size_in_bits % size_in_bits == 0); + KAI_ASSUME( + (block_num_zero_points * size_in_bits + block_num_data * size_in_bits) % size_in_bits == + 0); + KAI_ASSUME( + (block_num_data * size_in_bits + block_num_scales * size_in_bits) % size_in_bits == 0); + + std::vector dst(round_up_division( + num_blocks * + (block_num_zero_points * size_in_bits + block_num_data * size_in_bits + + block_num_scales * size_in_bits), + 8)); + auto* dst_ptr = dst.data(); + + for (size_t block_no = 0; block_no < num_blocks; ++block_no) { + for (size_t i = 0; i < block_num_zero_points; ++i) { + write_array( + dst_ptr, i, read_array(zero_points, block_no * block_num_zero_points + i)); + } + dst_ptr += block_num_zero_points * sizeof(ZeroPoint); + + for (size_t i = 0; i < block_num_data; ++i) { + write_array(dst_ptr, i, read_array(data, block_no * block_num_data + i)); + } + dst_ptr += round_up_division(block_num_data * size_in_bits, 8); + + for (size_t i = 0; i < block_num_scales; ++i) { + write_array(dst_ptr, i, read_array(scales, block_no * block_num_scales + i)); + } + dst_ptr += block_num_scales * sizeof(Scale); + } + + KAI_ASSERT(dst_ptr == &*dst.end()); + + return dst; +} + +template std::vector pack_zero_points_data_scales_per_block( + const void* zero_points, const void* data, const void* scales, size_t num_blocks, size_t block_num_zero_points, + size_t block_num_data, size_t block_num_scales); + template std::vector pack_data_scales_interleave_block( const void* data, const void* scales, size_t height, size_t width, size_t quant_width) { diff --git a/test/reference/pack.hpp b/test/reference/pack.hpp index 128ad0400b09a4a3ae8e1d503fc8ae0824837ebe..63c94d58f7e82e1790e971a55c6137e59d906abc 100644 --- a/test/reference/pack.hpp +++ b/test/reference/pack.hpp @@ -79,6 +79,70 @@ template std::vector pack_data_scales( const void* data, const void* scales, size_t height, size_t width, size_t quant_width); +/// Packs the zero point, data and scale into a single buffer. +/// +/// ``` +/// Data matrix: +/// +/// +-----------------+ +/// | q00 q01 q02 q03 | +/// | q10 q11 q12 q13 | +/// | q20 q21 q22 q23 | +/// | q30 q31 q32 q33 | +/// | ............... | +/// : ............... : +/// +/// Scales for each row: +/// +/// +----+ +/// | s0 | +/// | s1 | +/// | s2 | +/// | s3 | +/// | .. | +/// : .. : +/// +/// Zero points for each row: +/// +/// +----+ +/// | z0 | +/// | z1 | +/// | z2 | +/// | z3 | +/// | .. | +/// : .. : +/// ``` +/// +/// The packed data has each zero point followed by the data row followed by the scale. +/// +/// ``` +/// Packed data: +/// +/// +----+-----------------+----+ +/// | z0 | q00 q01 q02 q03 | s0 | +/// | z1 | q10 q11 q12 q13 | s1 | +/// | z2 | q20 q21 q22 q23 | s2 | +/// | z3 | q30 q31 q32 q33 | s3 | +/// | .. | ............... | .. | +/// : .. : ............... : .. : +/// ``` +/// +/// @tparam Data The data type of the data. +/// @tparam Scale The data type of the scale. +/// @tparam ZeroPoint The data type of the zero point. +/// +/// @param[in] data The data buffer. +/// @param[in] scales The scales buffer. +/// @param[in] zero_points The zero points buffer. +/// @param[in] height The number of rows. +/// @param[in] width The number of columns. +/// +/// @return The packed data buffer. +template +std::vector pack_zero_points_data_scales_per_block( + const void* zero_points, const void* data, const void* scales, size_t num_blocks, size_t block_num_zero_points, + size_t block_num_data, size_t block_num_scales); + /// Packs the quantized data and the quantization scale into a single buffer. /// /// ``` diff --git a/test/reference/quantize.cpp b/test/reference/quantize.cpp index 1a2acddfc080ac2ca1192cfb95d6f0156bc3d49c..82a09fd7614db1b9f25a927414f7a49ccc177c56 100644 --- a/test/reference/quantize.cpp +++ b/test/reference/quantize.cpp @@ -48,11 +48,13 @@ std::tuple get_scale_zero_point_from_range(FloatData min_v const FloatData scaled_max = max_value / scale; const FloatData zero_point_f = -(scaled_min + q_min) < scaled_max + q_max ? scaled_min - q_min : scaled_max - q_max; - const ZeroPoint zero_point = round_to_nearest_even(zero_point_f); + const ZeroPoint zero_point = -round_to_nearest_even(zero_point_f); return {scale, zero_point}; } +} // namespace + template IntType quantize_symmetric(float value, float scale) { const auto inv_scale = scale != 0 ? 1.0F / scale : 0.0F; @@ -68,12 +70,12 @@ IntType quantize_symmetric(float value, float scale) { template IntType quantize_asymmetric(FloatType value, FloatType scale, ZeroPointType zero_point) { const auto inv_scale = scale != 0 ? 1.0F / scale : 0.0F; - auto quantized_value = round_to_nearest_even(value * inv_scale) - zero_point; + auto quantized_value = round_to_nearest_even(value * inv_scale) + zero_point; return static_cast( std::clamp(quantized_value, numeric_lowest, numeric_highest)); } -} // namespace +template int8_t quantize_asymmetric(float value, float scale, int32_t zero_point); template std::vector compute_symmetric_per_block_quantization_info( @@ -102,7 +104,8 @@ std::vector compute_symmetric_per_block_quantization_info( } } - const auto scale = max_abs / ((1 << (size_in_bits - 1)) - 1); + const auto scale = + max_abs / static_cast((static_cast(1) << (size_in_bits - 1)) - 1); // Stores the scales. write_array(scales.data(), y * num_quant_packets_x + x_quant / quant_width, scale); @@ -144,6 +147,9 @@ std::vector quantize_symmetric_per_block( return data; } +template std::vector quantize_symmetric_per_block( + const void* src, const void* scales, size_t height, size_t width, size_t quant_width); + template std::tuple, std::vector> quantize_symmetric_per_block_dynamic( const void* src, size_t height, size_t width, size_t quant_width) { @@ -172,6 +178,8 @@ template std::tuple, std::vector> quantize_symmetr float, int8_t, Float16>(const void* src, size_t height, size_t width, size_t quant_width); template std::tuple, std::vector> quantize_symmetric_per_block_dynamic< float, int8_t, float>(const void* src, size_t height, size_t width, size_t quant_width); +template std::tuple, std::vector> quantize_symmetric_per_block_dynamic< + float, int32_t, float>(const void* src, size_t height, size_t width, size_t quant_width); template std::tuple, std::vector> compute_asymmetric_per_block_quantization_info( diff --git a/test/reference/quantize.hpp b/test/reference/quantize.hpp index 35532acc2345f9494607d438edf355da75a8ded0..3e2f162dc4dcc43a432753bcfd232beed00fe758 100644 --- a/test/reference/quantize.hpp +++ b/test/reference/quantize.hpp @@ -19,6 +19,20 @@ enum class QuantizationMethod : uint32_t { PER_ROW, ///< Per-row, i.e. one quantization scale and zero point for each row. }; +/// Quantized a float value to an integer datatype using a provided scale. +/// +/// @tparam IntType Quantized integer datatype. +/// +/// @param[in] float The value to quantize +/// @param[in] scale The scale used to quantize the provided float value. +/// +/// @return The quantized data matrix, the quantization scale matrix and the quantization zero point matrix. +template +IntType quantize_symmetric(float value, float scale); + +template +IntType quantize_asymmetric(FloatType value, FloatType scale, ZeroPointType zero_point); + /// Computes the quantization information using symmetric per-block quantization method. /// /// The input matrix is divided into quantization blocks of the same size. diff --git a/test/reference/reduce.cpp b/test/reference/reduce.cpp index 0e83b9bfdc7cf226635100e03cdd84c19497e207..d4935c3f94a13b7a0b25b22d04532f22ca85d362 100644 --- a/test/reference/reduce.cpp +++ b/test/reference/reduce.cpp @@ -6,6 +6,7 @@ #include "test/reference/reduce.hpp" +#include #include #include #include @@ -15,6 +16,7 @@ #include "test/common/data_type.hpp" #include "test/common/int4.hpp" #include "test/common/memory.hpp" +#include "test/common/round.hpp" namespace kai::test { @@ -110,4 +112,53 @@ std::vector reduce_add( return reduce_any_op(src, src_format, height, width, dst_format, dimension); } +template +std::vector reduce_add_x(const void* src, size_t height, size_t width) { + std::vector dst(round_up_division(height * size_in_bits, 8)); + + for (size_t y = 0; y < height; ++y) { + Accumulator acc = 0; + + for (size_t x = 0; x < width; ++x) { + acc += static_cast(read_array(src, y * width + x)); + } + + write_array(dst.data(), y, acc); + } + + return dst; +} + +template std::vector reduce_add_x(const void* src, size_t height, size_t width); + +template +T reduce_min(const void* src, size_t len) { + KAI_ASSUME(len > 0); + + T min = read_array(src, 0); + + for (size_t i = 1; i < len; ++i) { + min = std::min(min, read_array(src, i)); + } + + return min; +} + +template float reduce_min(const void* src, size_t len); + +template +T reduce_max(const void* src, size_t len) { + KAI_ASSUME(len > 0); + + T max = read_array(src, 0); + + for (size_t i = 1; i < len; ++i) { + max = std::max(max, read_array(src, i)); + } + + return max; +} + +template float reduce_max(const void* src, size_t len); + } // namespace kai::test diff --git a/test/reference/reduce.hpp b/test/reference/reduce.hpp index f6ba197a0f5bdcf9ef16a35fe4ee4e74cfe50450..a3ccec7c72a28f17e3b7f74c9120f40712ce5c5f 100644 --- a/test/reference/reduce.hpp +++ b/test/reference/reduce.hpp @@ -33,4 +33,39 @@ std::vector reduce_add( const void* src, const DataFormat& src_format, size_t height, size_t width, const DataFormat& dst_format, size_t dimension); +/// Accumulates the matrix along the first dimension. +/// +/// @tparam Value The data type of the matrix value. +/// @tparam Accumulator The data type of the accumulator. +/// +/// @param[in] src The input data. +/// @param[in] height The number of rows of the input matrix. +/// @param[in] width The number of columns of the input matrix. +/// +/// @return The vector containing the sum of each input matrix row. +template +std::vector reduce_add_x(const void* src, size_t height, size_t width); + +/// Retrieve the minimum value in a provided matrix. +/// +/// @tparam T Datatype of source matrix +/// +/// @param[in] src The input data +/// @param[in] len The number of values within the source matrix. +/// +/// @return The quantized data matrix, the quantization scale matrix and the quantization zero point matrix. +template +T reduce_min(const void* src, size_t len); + +/// Retrieve the maximum value in a provided matrix. +/// +/// @tparam T Datatyoe of source matrix +/// +/// @param[in] src The input data +/// @param[in] len The number of values within the source matrix. +/// +/// @return The quantized data matrix, the quantization scale matrix and the quantization zero point matrix. +template +T reduce_max(const void* src, size_t len); + } // namespace kai::test diff --git a/test/reference/reorder.cpp b/test/reference/reorder.cpp new file mode 100644 index 0000000000000000000000000000000000000000..564f96f6ae6f9dbd4cd7fd7d24834bf079f6fe9e --- /dev/null +++ b/test/reference/reorder.cpp @@ -0,0 +1,50 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/reorder.hpp" + +#include +#include +#include + +#include "test/common/memory.hpp" +#include "test/common/round.hpp" + +namespace kai::test { + +template +std::vector reorder_block( + const void* src, size_t height, size_t width, size_t block_height, size_t block_width) { + const auto num_dst_elements = round_up_multiple(height, block_height) * round_up_multiple(width, block_width); + const auto dst_size = round_up_division(num_dst_elements * size_in_bits, 8); + + std::vector dst(dst_size); + size_t dst_index = 0; + + for (size_t y_block = 0; y_block < height; y_block += block_height) { + for (size_t x_block = 0; x_block < width; x_block += block_width) { + for (size_t y_element = 0; y_element < block_height; ++y_element) { + for (size_t x_element = 0; x_element < block_width; ++x_element) { + const auto y = y_block + y_element; + const auto x = x_block + x_element; + + if (y < height && x < width) { + write_array(dst.data(), dst_index, read_array(src, y * width + x)); + } + + ++dst_index; + } + } + } + } + + return dst; +} + +template std::vector reorder_block( + const void* src, size_t height, size_t width, size_t block_height, size_t block_width); + +} // namespace kai::test diff --git a/test/reference/reorder.hpp b/test/reference/reorder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..48449e37c197d710b5cc73b6d28f3ae862acdef0 --- /dev/null +++ b/test/reference/reorder.hpp @@ -0,0 +1,72 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace kai::test { + +/// Reorders the input matrix block by block. +/// +/// Example: +/// +/// The input matrix: 5x7. +/// +/// ``` +/// +-----------------------------+ +/// | a00 a01 a02 a03 a04 a05 a06 | +/// | a10 a11 a12 a13 a14 a15 a16 | +/// | a20 a21 a22 a23 a24 a25 a26 | +/// | a30 a31 a32 a33 a34 a35 a36 | +/// | a40 a41 a42 a43 a44 a45 a46 | +/// +-----------------------------+ +/// ``` +/// +/// The matrix is divided into blocks of 2x3. +/// At the right and bottom edges, the partial blocks are padded with 0s. +/// +/// ``` +// +-------------+-------------+-------------+ +/// | a00 a01 a02 | a03 a04 a05 | a06 0 0 | +/// | a10 a11 a12 | a13 a14 a15 | a16 0 0 | +/// +-------------+-------------+-------------+ +/// | a20 a21 a22 | a23 a24 a25 | a26 0 0 | +/// | a30 a31 a32 | a33 a34 a35 | a36 0 0 | +/// +-------------+-------------+-------------+ +/// | a40 a41 a42 | a43 a44 a45 | a46 0 0 | +/// | 0 0 0 | 0 0 0 | 0 0 0 | +/// +-------------+-------------+-------------+ +/// ``` +/// +/// Each block is then flatten to get the final reordered matrix: +/// +/// ``` +/// +-------------------------+-------------------------+-------------------------+ +/// | a00 a01 a02 a10 a11 a12 | a03 a04 a05 a13 a14 a15 | a06 0 0 a16 0 0 | +/// +-------------------------+-------------------------+-------------------------+ +/// | a20 a21 a22 a30 a31 a32 | a23 a24 a25 a33 a34 a35 | a26 0 0 a36 0 0 | +/// +-------------------------+-------------------------+-------------------------+ +/// | a40 a41 a42 0 0 0 | a43 a44 a45 0 0 0 | a46 0 0 0 0 0 | +/// +-------------------------+-------------------------+-------------------------+ +/// +/// @tparam T The data type. +/// +/// @param[in] src The input data. +/// @param[in] height The number of rows of the input matrix. +/// @param[in] width The number of columns of the input matrix. +/// @param[in] block_height The number of rows of a block. +/// @param[in] block_width The number of columns of a block. +/// +/// @param[in] The reordered matrix. +/// ``` +template +std::vector reorder_block( + const void* src, size_t height, size_t width, size_t block_height, size_t block_width); + +} // namespace kai::test diff --git a/test/reference/transpose.cpp b/test/reference/transpose.cpp index d0ca5590e7e0e2dfc33ffe63ee017a4796f99e98..84958422d42600d4f1b97012fcbae78ceafd95e2 100644 --- a/test/reference/transpose.cpp +++ b/test/reference/transpose.cpp @@ -14,6 +14,7 @@ #include "kai/kai_common.h" #include "test/common/data_type.hpp" #include "test/common/memory.hpp" +#include "test/common/round.hpp" namespace kai::test { @@ -38,7 +39,7 @@ std::vector transpose(const void* data, DataType data_type, size_t heig } template -std::vector transpose( +std::vector transpose_with_padding( const void* data, const size_t height, const size_t width, const size_t src_stride, const size_t dst_stride, const size_t dst_size) { std::vector output(dst_size); @@ -53,11 +54,28 @@ std::vector transpose( return output; } -template std::vector transpose( +template std::vector transpose_with_padding( const void* data, const size_t height, const size_t width, const size_t src_stride, const size_t dst_stride, const size_t dst_size); -template std::vector transpose( +template std::vector transpose_with_padding( const void* data, const size_t height, const size_t width, const size_t src_stride, const size_t dst_stride, const size_t dst_size); + +template +std::vector transpose(const void* src, size_t height, size_t width) { + std::vector dst(round_up_division(height * width * size_in_bits, 8)); + + for (size_t y = 0; y < width; ++y) { + for (size_t x = 0; x < height; ++x) { + write_array(dst.data(), y * height + x, read_array(src, x * width + y)); + } + } + + return dst; +} + +template std::vector transpose(const void* src, size_t height, size_t width); +template std::vector transpose(const void* src, size_t height, size_t width); + } // namespace kai::test diff --git a/test/reference/transpose.hpp b/test/reference/transpose.hpp index 63d94b5bc0ef542c3a03b9d0d6585fd3923595a9..306bc89dfd3e768a2a9134e46889d2d6d9913d98 100644 --- a/test/reference/transpose.hpp +++ b/test/reference/transpose.hpp @@ -37,7 +37,18 @@ std::vector transpose(const void* data, DataType data_type, size_t heig /// @return The transposed matrix. /// template -std::vector transpose( +std::vector transpose_with_padding( const void* data, size_t height, size_t width, size_t src_stride, size_t dst_stride, size_t dst_size); +/// +/// @tparam T The data type. +/// +/// @param[in] src The data buffer of the source matrix. +/// @param[in] height The number of rows of the source matrix. +/// @param[in] width The number of columns of the source matrix. +/// +/// @return The transposed matrix. +template +std::vector transpose(const void* src, size_t height, size_t width); + } // namespace kai::test diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp index 0155d7d80701810f2c9ee523e21100cfd228eb41..c24361f544710cb967ac05e3f302901aa6fca68d 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp @@ -182,7 +182,7 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_kxn) { const auto [ref_rhs_qsi4_transposed, ref_rhs_scales] = quantize_symmetric_per_block_dynamic(ref_rhs_transposed.data(), N, K, bl); - auto ref_rhs_qsi4 = transpose( + auto ref_rhs_qsi4 = transpose_with_padding( ref_rhs_qsi4_transposed.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride, ref_rhs_qsi4_kxn_size_bytes); diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp index b245e7f55174a0dec670d6d2daba6f15661470f4..ecc7f70318885a586fc77ed00ab0b99b03a21a59 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp @@ -253,7 +253,7 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsi4cx) { const auto [ref_rhs_qsi4_transposed, ref_rhs_scales] = quantize_symmetric_per_block_dynamic(ref_rhs.data(), N, K, K); - const auto ref_rhs_qsi4 = transpose( + const auto ref_rhs_qsi4 = transpose_with_padding( ref_rhs_qsi4_transposed.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride, ref_rhs_qsi4_kxn_size_bytes); @@ -342,7 +342,7 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsu4cx) { const auto [ref_rhs_qsi4_transposed, ref_rhs_scales] = quantize_symmetric_per_block_dynamic(ref_rhs.data(), N, K, K); - const auto ref_rhs_qsi4 = transpose( + const auto ref_rhs_qsi4 = transpose_with_padding( ref_rhs_qsi4_transposed.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride, ref_rhs_qsi4_kxn_size_bytes); diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp index 6360c4331efbf02bc7996f3188a146eca59456d9..db3c530f386daf5640a448edfc12d2a8188ad86d 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp @@ -162,7 +162,7 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi8cxp, EndToEnd_RHS_kxn_qsi8cx) { const auto [ref_rhs_qsi8_transposed, ref_rhs_scales] = quantize_symmetric_per_block_dynamic(ref_rhs.data(), N, K, K); - const auto ref_rhs_qsi8 = transpose( + const auto ref_rhs_qsi8 = transpose_with_padding( ref_rhs_qsi8_transposed.data(), N, K, ref_rhs_qsi8_nxk_stride, ref_rhs_qsi8_kxn_stride, ref_rhs_qsi8_kxn_size_bytes); diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..985e81669ca4dab4cee049721b51952f12a50f15 --- /dev/null +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -0,0 +1,379 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include + +#include "kai/kai_common.h" +#include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" +#include "kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" +#include "test/common/cpu_info.hpp" +#include "test/common/matrix_portion.hpp" +#include "test/common/memory.hpp" +#include "test/common/rect.hpp" +#include "test/common/sme.hpp" +#include "test/reference/binary_elementwise.hpp" +#include "test/reference/clamp.hpp" +#include "test/reference/fill.hpp" +#include "test/reference/matmul.hpp" +#include "test/reference/matmul_pack.hpp" +#include "test/reference/quantize.hpp" +#include "test/reference/reduce.hpp" +#include "test/reference/reorder.hpp" +#include "test/reference/transpose.hpp" + +namespace kai::test { + +namespace { + +struct GemmVariant { + size_t acc_height; + size_t acc_width; + size_t acc_fanin; + + bool (*fn_is_supported)(); + + size_t (*fn_pack_lhs_get_m_step)(size_t mr); + size_t (*fn_pack_lhs_get_lhs_offset)(size_t m_idx, size_t lhs_stride); + size_t (*fn_pack_lhs_get_packed_lhs_offset)(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr); + size_t (*fn_pack_lhs_get_packed_lhs_size)(size_t m, size_t k, size_t mr, size_t kr, size_t sr); + void (*fn_pack_lhs_run)( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, + void* lhs_packed); + + size_t (*fn_pack_rhs_get_n_step)(); + size_t (*fn_pack_rhs_get_rhs_offset)(size_t n_idx); + size_t (*fn_pack_rhs_get_bias_offset)(size_t n_idx); + size_t (*fn_pack_rhs_get_scale_offset)(size_t n_idx); + size_t (*fn_pack_rhs_get_packed_rhs_offset)(size_t n_idx, size_t k); + size_t (*fn_pack_rhs_get_packed_rhs_size)(size_t n, size_t k); + void (*fn_pack_rhs_run)( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_qsi8_params* params); + + size_t (*fn_main_get_m_step)(); + size_t (*fn_main_get_n_step)(); + size_t (*fn_main_get_mr)(); + size_t (*fn_main_get_nr)(); + size_t (*fn_main_get_kr)(); + size_t (*fn_main_get_sr)(); + size_t (*fn_main_get_packed_lhs_offset)(size_t m_idx, size_t k); + size_t (*fn_main_get_packed_rhs_offset)(size_t n_idx, size_t k); + size_t (*fn_main_get_dst_offset)(size_t m_idx, size_t n_idx, size_t dst_stride); + size_t (*fn_main_get_dst_size)(size_t m, size_t n); + void (*fn_main_run)( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, + size_t dst_stride_col, const kai_matmul_requantize32_params* params); +}; + +struct GemmShape { + size_t m; + size_t n; + size_t k; +}; + +const std::array gemm_variants = { + GemmVariant{ + .acc_height = 2 * get_sme_vector_length(), + .acc_width = 2 * get_sme_vector_length(), + .acc_fanin = sizeof(int32_t) / sizeof(int8_t), + + .fn_is_supported = cpu_has_sme2, + + .fn_pack_lhs_get_m_step = kai_get_m_step_lhs_pack_x8p2vlx4_x8_sme, + .fn_pack_lhs_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x8p2vlx4_x8_sme, + .fn_pack_lhs_get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_pack_x8p2vlx4_x8_sme, + .fn_pack_lhs_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x8p2vlx4_x8_sme, + .fn_pack_lhs_run = kai_run_lhs_pack_x8p2vlx4_x8_sme, + + .fn_pack_rhs_get_n_step = kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + .fn_pack_rhs_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + .fn_pack_rhs_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + .fn_pack_rhs_get_scale_offset = kai_get_scale_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + .fn_pack_rhs_get_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + .fn_pack_rhs_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + .fn_pack_rhs_run = kai_run_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + + .fn_main_get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .fn_main_get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .fn_main_get_mr = kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .fn_main_get_nr = kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .fn_main_get_kr = kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .fn_main_get_sr = kai_get_sr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .fn_main_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .fn_main_get_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .fn_main_get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .fn_main_get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .fn_main_run = kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + }, +}; + +constexpr float output_clamp_rate = 0.1F; // Clamping 10% the range of the output. + +const std::array gemm_shapes = { + GemmShape{1, 1, 1}, // + GemmShape{ + 2 * get_sme_vector_length(), 2 * get_sme_vector_length(), + sizeof(int32_t) / sizeof(int8_t)}, // + GemmShape{20, 30, 40}, // + GemmShape{1, 49, 21}, // + GemmShape{23, 1, 43}, // + GemmShape{32, 14, 1}, // + GemmShape{123, 85, 45}, // + GemmShape{130, 130, 6}, +}; + +const std::array output_portions = { + MatrixPortion(0, 0, 1, 1), // Full matrix. + MatrixPortion(0, 0, 0.25, 0.25), // Top-left corner. + MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. +}; + +void run_test(const GemmShape& shape, const GemmVariant& variant, const MatrixPortion& output_portion) { + const uint64_t seed = 0; + + if (!variant.fn_is_supported()) { + GTEST_SKIP(); + } + + // ============================================================ + // Test the packing and scheduling parameters + // ============================================================ + + const auto imp_mr = variant.fn_main_get_mr(); + const auto imp_nr = variant.fn_main_get_nr(); + const auto imp_kr = variant.fn_main_get_kr(); + const auto imp_sr = variant.fn_main_get_sr(); + + ASSERT_EQ(imp_mr, variant.acc_height); + ASSERT_EQ(imp_nr, variant.acc_width); + ASSERT_EQ(imp_kr, variant.acc_fanin); + ASSERT_EQ(imp_sr, 1); + + const auto imp_m_step = variant.fn_main_get_m_step(); + const auto imp_n_step = variant.fn_main_get_n_step(); + + ASSERT_EQ(imp_m_step, variant.acc_height); + ASSERT_EQ(imp_n_step, variant.acc_width); + + // ============================================================ + // Calculates the output area under test + // ============================================================ + + const auto output_area = output_portion.compute_portion(shape.m, shape.n, variant.acc_height, variant.acc_width); + + // ============================================================ + // Generates input and reference output data + // ============================================================ + + // Generates the input data in floating-point. + const auto lhs_f32 = fill_random(shape.m * shape.k, seed + 0); + const auto rhs_f32 = fill_random(shape.k * shape.n, seed + 1); + const auto bias_f32 = fill_random(shape.n, seed + 2); + + // Quantizes the input data. + // * LHS: 8-bit asymmetric per-matrix quantization. + // * RHS: 8-bit symmetric per-channel quantization. + // * Bias: 32-bit symmetric per-channel quantization. + const auto [lhs_qai8, lhs_qai8_scales, lhs_qai8_zero_points] = + quantize_asymmetric_per_block_dynamic( + lhs_f32.data(), 1, shape.m * shape.k, shape.m * shape.k); + const auto lhs_scale = read_array(lhs_qai8_scales.data(), 0); + const auto lhs_zero_point = read_array(lhs_qai8_zero_points.data(), 0); + + const auto rhs_f32_t = transpose(rhs_f32.data(), shape.k, shape.n); + const auto [rhs_qsi8_t, rhs_scales] = + quantize_symmetric_per_block_dynamic(rhs_f32_t.data(), shape.n, shape.k, shape.k); + const auto rhs_qsi8 = transpose(rhs_qsi8_t.data(), shape.n, shape.k); + + const auto bias_scale = mul(&lhs_scale, 1, 1, rhs_scales.data(), 1, shape.n); + const auto bias_qsi32 = + quantize_symmetric_per_block(bias_f32.data(), bias_scale.data(), shape.n, 1, 1); + + // Runs the reference implementation of matmul to produce floating-point result. + const auto ref_dst_f32 = + matmul_nt_t_quantized( + shape.m, shape.n, shape.k, lhs_qai8.data(), &lhs_scale, &lhs_zero_point, shape.m, shape.k, + rhs_qsi8_t.data(), rhs_scales.data(), nullptr, 1, shape.k, bias_qsi32.data(), bias_scale.data(), nullptr, + 1); + + // Computes the output quantization information and clamping limits. + // + // To get a realistic value for the output quantization information and clamping limits + // and avoid uncontrolled saturation problem, these information will be calculated + // based on the reference floating-point output. + // + // The clamping limits will be slightly narrower than the actual range of the output + // so that a portion of the output will be clampped. + const auto [dst_scales, dst_zero_points] = + compute_asymmetric_per_block_quantization_info( + ref_dst_f32.data(), 1, shape.m * shape.n, shape.m * shape.n); + const auto dst_scale = read_array(dst_scales.data(), 0); + const auto dst_zero_point = read_array(dst_zero_points.data(), 0); + + const auto ref_dst_f32_min = reduce_min(ref_dst_f32.data(), shape.m * shape.n); + const auto ref_dst_f32_max = reduce_max(ref_dst_f32.data(), shape.m * shape.n); + const auto ref_dst_f32_range = ref_dst_f32_max - ref_dst_f32_min; + + const auto ref_dst_f32_clamp_min = ref_dst_f32_min + ref_dst_f32_range * output_clamp_rate / 2; + const auto ref_dst_f32_clamp_max = ref_dst_f32_max - ref_dst_f32_range * output_clamp_rate / 2; + const auto dst_qai8_clamp_min = + quantize_asymmetric(ref_dst_f32_clamp_min, dst_scale, dst_zero_point); + const auto dst_qai8_clamp_max = + quantize_asymmetric(ref_dst_f32_clamp_max, dst_scale, dst_zero_point); + + // Clamps and quantizes the reference output matrix. + const auto ref_dst_f32_clamped = + clamp(ref_dst_f32.data(), shape.m * shape.n, ref_dst_f32_clamp_min, ref_dst_f32_clamp_max); + const auto ref_dst_qsi8_clamped = quantize_asymmetric_per_block( + ref_dst_f32_clamped.data(), &dst_scale, &dst_zero_point, 1, shape.m * shape.n, shape.m * shape.n); + + // Runs the reference implementation of the packing functions. + // + // The reference packing functions cannot be executed earlier + // because we need the reference floating-point output first to have + // the quantization information. + const auto ref_packed_lhs = + reorder_block(lhs_qai8.data(), shape.m, shape.k, variant.acc_height, variant.acc_fanin); + + const auto ref_packed_rhs = matmul_pack_rhs_nxk_static_quantized( + rhs_qsi8_t.data(), rhs_scales.data(), lhs_scale, dst_scale, bias_qsi32.data(), lhs_zero_point, shape.n, shape.k, + variant.acc_width, variant.acc_fanin); + + // ============================================================ + // Runs the optimized implementation and checks for correctness + // ============================================================ + + // Runs the optimized implementation of LHS packing. + const auto imp_packed_lhs_size = + variant.fn_pack_lhs_get_packed_lhs_size(shape.m, shape.k, variant.acc_height, variant.acc_fanin, 1); + ASSERT_EQ(imp_packed_lhs_size, ref_packed_lhs.size()); + std::vector imp_packed_lhs(imp_packed_lhs_size); + + { + const auto imp_lhs_offset = + variant.fn_pack_lhs_get_lhs_offset(output_area.start_row(), shape.k * sizeof(int8_t)); + const auto imp_packed_lhs_offset = + variant.fn_pack_lhs_get_packed_lhs_offset(output_area.start_row(), shape.k, imp_mr, imp_kr, imp_sr); + + variant.fn_pack_lhs_run( + output_area.height(), shape.k, imp_mr, imp_kr, imp_sr, 0, lhs_qai8.data() + imp_lhs_offset, + shape.k * sizeof(int8_t), imp_packed_lhs.data() + imp_packed_lhs_offset); + + const auto imp_packed_lhs_end_offset = output_area.end_row() < shape.m + ? variant.fn_pack_lhs_get_packed_lhs_offset(output_area.end_row(), shape.k, imp_mr, imp_kr, imp_sr) + : imp_packed_lhs_size; + + for (size_t i = 0; i < ref_packed_lhs.size(); ++i) { + if (i >= imp_packed_lhs_offset && i < imp_packed_lhs_end_offset) { + ASSERT_EQ(imp_packed_lhs[i], ref_packed_lhs[i]); + } else { + ASSERT_EQ(imp_packed_lhs[i], 0); + } + } + } + + // Runs the optimized implementation of RHS packing. + const auto imp_packed_rhs_size = variant.fn_pack_rhs_get_packed_rhs_size(shape.n, shape.k); + ASSERT_EQ(imp_packed_rhs_size, ref_packed_rhs.size()); + std::vector imp_packed_rhs(imp_packed_rhs_size); + + { + const auto imp_rhs_offset = variant.fn_pack_rhs_get_rhs_offset(output_area.start_col()); + const auto imp_bias_offset = variant.fn_pack_rhs_get_bias_offset(output_area.start_col()); + const auto imp_scale_offset = variant.fn_pack_rhs_get_scale_offset(output_area.start_col()); + const auto imp_packed_rhs_offset = variant.fn_pack_rhs_get_packed_rhs_offset(output_area.start_col(), shape.k); + + const kai_rhs_pack_qsi8_params imp_pack_rhs_params{ + .lhs_zero_point = lhs_zero_point, + .scale_multiplier = lhs_scale / dst_scale, + }; + + variant.fn_pack_rhs_run( + 1, output_area.width(), shape.k, imp_nr, imp_kr, imp_sr, shape.n * sizeof(int8_t), + rhs_qsi8.data() + imp_rhs_offset, bias_qsi32.data() + imp_bias_offset, rhs_scales.data() + imp_scale_offset, + imp_packed_rhs.data() + imp_packed_rhs_offset, 0, &imp_pack_rhs_params); + + const auto imp_packed_rhs_end_offset = output_area.end_col() < shape.n + ? variant.fn_pack_rhs_get_packed_rhs_offset(output_area.end_col(), shape.k) + : imp_packed_rhs_size; + + for (size_t i = 0; i < ref_packed_rhs.size(); ++i) { + if (i >= imp_packed_rhs_offset && i < imp_packed_rhs_end_offset) { + ASSERT_EQ(imp_packed_rhs[i], ref_packed_rhs[i]); + } else { + ASSERT_EQ(imp_packed_rhs[i], 0); + } + } + } + + // Runs the optimized implementation of GEMM kernel. + const auto imp_dst_size = variant.fn_main_get_dst_size(shape.m, shape.n); + ASSERT_EQ(imp_dst_size, ref_dst_qsi8_clamped.size()); + + std::vector imp_dst(imp_dst_size); + + { + const auto imp_packed_lhs_offset = variant.fn_main_get_packed_lhs_offset(output_area.start_row(), shape.k); + const auto imp_packed_rhs_offset = variant.fn_main_get_packed_rhs_offset(output_area.start_col(), shape.k); + const auto imp_dst_offset = + variant.fn_main_get_dst_offset(output_area.start_row(), output_area.start_col(), shape.n * sizeof(int8_t)); + ASSERT_EQ(imp_dst_offset, output_area.start_row() * shape.n + output_area.start_col()); + + const kai_matmul_requantize32_params imp_main_params{ + .min_value = dst_qai8_clamp_min, + .max_value = dst_qai8_clamp_max, + .output_zero_point = dst_zero_point, + }; + + variant.fn_main_run( + output_area.height(), output_area.width(), shape.k, imp_packed_lhs.data() + imp_packed_lhs_offset, + imp_packed_rhs.data() + imp_packed_rhs_offset, imp_dst.data() + imp_dst_offset, shape.n * sizeof(int8_t), + sizeof(int8_t), &imp_main_params); + + for (size_t y = 0; y < shape.m; ++y) { + for (size_t x = 0; x < shape.n; ++x) { + const auto i = y * shape.n + x; + const auto in_area = y >= output_area.start_row() && y < output_area.end_row() && + x >= output_area.start_col() && x < output_area.end_col(); + + const int32_t imp_value = read_array(imp_dst.data(), i); + const int32_t ref_value = in_area ? read_array(ref_dst_qsi8_clamped.data(), i) : 0; + const auto error = std::abs(imp_value - ref_value); + const auto threshold = in_area ? 1 : 0; + + if (error > threshold) { + ASSERT_EQ(imp_value, ref_value); + } + } + } + } +} + +using ThisTest = testing::TestWithParam>; + +TEST_P(ThisTest, EndToEnd) { + const auto& [variant, shape, output_portion] = GetParam(); + + run_test(shape, variant, output_portion); +} + +} // namespace + +INSTANTIATE_TEST_SUITE_P( + matmul_clamp_qai8_qai8p_qsi8cxp, ThisTest, + testing::Combine( + testing::ValuesIn(gemm_variants), testing::ValuesIn(gemm_shapes), testing::ValuesIn(output_portions))); + +} // namespace kai::test