diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a7896e92adb89e7a2d75a30f957ed3920cfa23d..b7fa829c61dd5f912f2187709ba358ea6262cb43 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,10 +14,14 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - Matrix multiplication (MxN) Micro-kernels of QAI8DX LHS and QSI8CX RHS with F16 output, optimized for FEAT_I8MM and FEAT_DotProd. - Matrix multiplication (1xN) Micro-kernels of QAI8DX LHS and QSI8CX RHS with F16 output, optimized for FEAT_DotProd. - New SME micro-kernels: - - Indirect matrix multiplication (MxN) of FP16 input and output. + - Indirect matrix multiplication (MxN) of F16 input and output. + - Packing kernels for LHS and RHS + - Indirect matrix multiplication (MxN) of F32 input and output. - Packing kernels for LHS and RHS - New SME2 micro-kernels: - - Indirect matrix multiplication (MxN) of FP16 input and output. + - Indirect matrix multiplication (MxN) of F16 input and output. + - Matrix multiplication of packed indirect LHS and packed RHS + - Indirect matrix multiplication (MxN) of F32 input and output. 
- Matrix multiplication of packed indirect LHS and packed RHS - Disable link time optimization for microkernel library diff --git a/CMakeLists.txt b/CMakeLists.txt index f2221ecf17a783e673ca65d8758098a639c74789..ad8f5580a873f9eeb5cded48e45909ee7856d7ab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -223,12 +223,14 @@ set(KLEIDIAI_FILES_NEON_I8MM set(KLEIDIAI_FILES_SME kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c + kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c + kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c @@ -239,6 +241,7 @@ set(KLEIDIAI_FILES_SME set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c + kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 7408dfff16ef6f8fa1898942e7d61a011d333b11..2da2800c3604178bcba0b0c8cbcd9c7e9c3615a5 100644 
--- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -139,6 +139,7 @@ I8MM_KERNELS_ASM = [ # buildifier: keep sorted SME_KERNELS = [ "pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme", + "pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme", "pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme", "pack/kai_lhs_pack_bf16p2vlx2_f32_sme", "pack/kai_lhs_pack_f32p2vlx1_f32_sme", @@ -146,6 +147,7 @@ SME_KERNELS = [ "pack/kai_lhs_pack_x8p2vlx4_x8_sme", "pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", "pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme", + "pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme", "pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme", "pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme", "pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", @@ -158,6 +160,7 @@ SME_KERNELS = [ # buildifier: keep sorted SME2_KERNELS = [ "imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", + "imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa", "imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", "matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot", "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", diff --git a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c new file mode 100644 index 0000000000000000000000000000000000000000..a927e2b7fd4cb4ec3badbdc4f6d33a022b16dec4 --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c @@ -0,0 +1,311 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Do not flag up inline assembly blocks 
+#pragma GCC diagnostic ignored "-Woverlength-strings" + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include "kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h" + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_mr = 2; +static const size_t kai_nr = 2; +static const size_t kai_kr = 1; + +size_t kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(void) { + return kai_mr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(void) { + return kai_nr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa() == 0); + const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); + return m_idx * indirect_k * sizeof(float); +} + +static size_t kai_get_rhs_packed_stride_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( + size_t k_chunk_count, size_t k_chunk_length) { + const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); + return kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa() * + (sizeof(float) + indirect_k * sizeof(float)); +} + +size_t kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa() == 0); + const size_t block_idx = n_idx / kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + return block_idx * + kai_get_rhs_packed_stride_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( + k_chunk_count, k_chunk_length); 
+} + +size_t kai_get_dst_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_row_stride) { + KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa() == 0); + KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa() == 0); + + return m_idx * dst_row_stride + n_idx * sizeof(float); +} + +size_t kai_get_dst_size_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, + void* dst, size_t dst_row_stride, float clamp_min, float clamp_max) { + typedef struct { + const void* A; // Packed LHS + const void* B; // Packed RHS + void* C; // Destination + uint64_t ldcb; // Destination row stride, added to the output pointer after each stored row + uint64_t M; // Output row count; the assembly reads it with a 32-bit load (ldr w14) + uint64_t N; // Output column count; the assembly reads it with a 32-bit load (ldr w10) + uint64_t K; // k_chunk_count * kai_roundup(k_chunk_length, kai_kr) + float min; // Lower clamp bound + float max; // Upper clamp bound + void* accumulator_buffer; // Set to NULL below; no asm operand refers to it + uint64_t flags; // Set to 0 below; no asm operand refers to it + } KernelArgs; + + KernelArgs args; + + args.A = lhs_packed; + args.B = rhs_packed; + + const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); + + args.C = dst; + args.ldcb = dst_row_stride; + args.M = m; + args.N = n; + args.K = indirect_k; + args.min = clamp_min; + args.max = clamp_max; + + args.accumulator_buffer = NULL; + args.flags = 0; + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "ldr w14, [%x[args], %[offsetof_M]]\n" + "mov x13, #0x0\n" + "mov x11, #0x0\n" + "ptrue p0.b\n" + ".inst 0x25207811 // ptrue pn9.b\n" + "ldr w10, [%x[args], %[offsetof_N]]\n" + "ldr x9, [%x[args], %[offsetof_A]]\n" + "1:" // M loop + "ldr x28, [%x[args], %[offsetof_B]]\n" + "2:" // N loop + ".inst 0x25aa4570 // whilelt pn8.s, x11, x10, VLx2\n" + "fmov z13.s, #1.0\n" + ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" + "mov x27, x9\n" + ".inst 0xa040438e // ld1w { z14.s-z15.s }, p8/Z, [x28]\n" // Load bias + "addvl x28, 
x28, #2\n" + ".inst 0x808e01a0 // fmopa za0.s, p0/M, p0/M, z13.s, z14.s\n" + ".inst 0x808f01a1 // fmopa za1.s, p0/M, p0/M, z13.s, z15.s\n" + ".inst 0x808e01a2 // fmopa za2.s, p0/M, p0/M, z13.s, z14.s\n" + ".inst 0x808f01a3 // fmopa za3.s, p0/M, p0/M, z13.s, z15.s\n" + "ldr x20, [%x[args], %[offsetof_K]]\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 6f\n" + "subs x21, x21, #0x1\n" + ".inst 0xa1404772 // ld1w { z18.s, z26.s }, pn9.b/Z, [x27]\n" + ".inst 0xa0404794 // ld1w { z20.s-z21.s }, pn9.b/Z, [x28]\n" + ".inst 0xa1414764 // ld1w { z4.s, z12.s }, pn9.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0xa041478a // ld1w { z10.s-z11.s }, pn9.b/Z, [x28, #0x2, MUL VL]\n" + ".inst 0xa1424773 // ld1w { z19.s, z27.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa0424798 // ld1w { z24.s-z25.s }, pn9.b/Z, [x28, #0x4, MUL VL]\n" + ".inst 0xa043476e // ld1w { z14.s-z15.s }, pn9.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + ".inst 0xa1434796 // ld1w { z22.s, z30.s }, pn9.b/Z, [x28, #0x6, MUL VL]\n" + "addvl x28, x28, #8\n" + "ble 5f\n" + "4:" // K loop + ".inst 0x80940240 // fmopa za0.s, p0/M, p0/M, z18.s, z20.s\n" + "subs x21, x21, #0x1\n" + ".inst 0x80950241 // fmopa za1.s, p0/M, p0/M, z18.s, z21.s\n" + ".inst 0x80940342 // fmopa za2.s, p0/M, p0/M, z26.s, z20.s\n" + ".inst 0x80950343 // fmopa za3.s, p0/M, p0/M, z26.s, z21.s\n" + ".inst 0xa1404772 // ld1w { z18.s, z26.s }, pn9.b/Z, [x27]\n" + ".inst 0x808a0080 // fmopa za0.s, p0/M, p0/M, z4.s, z10.s\n" + ".inst 0xa0404794 // ld1w { z20.s-z21.s }, pn9.b/Z, [x28]\n" + ".inst 0x808b0081 // fmopa za1.s, p0/M, p0/M, z4.s, z11.s\n" + ".inst 0x808a0182 // fmopa za2.s, p0/M, p0/M, z12.s, z10.s\n" + ".inst 0x808b0183 // fmopa za3.s, p0/M, p0/M, z12.s, z11.s\n" + ".inst 0xa1414764 // ld1w { z4.s, z12.s }, pn9.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0x80980260 // fmopa za0.s, p0/M, p0/M, z19.s, z24.s\n" + ".inst 0xa041478a // ld1w { z10.s-z11.s }, pn9.b/Z, [x28, #0x2, MUL VL]\n" + ".inst 0x80990261 // fmopa za1.s, p0/M, 
p0/M, z19.s, z25.s\n" + ".inst 0x80980362 // fmopa za2.s, p0/M, p0/M, z27.s, z24.s\n" + ".inst 0x80990363 // fmopa za3.s, p0/M, p0/M, z27.s, z25.s\n" + ".inst 0xa1424773 // ld1w { z19.s, z27.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa0424798 // ld1w { z24.s-z25.s }, pn9.b/Z, [x28, #0x4, MUL VL]\n" + ".inst 0x809601c0 // fmopa za0.s, p0/M, p0/M, z14.s, z22.s\n" + ".inst 0x809e01c1 // fmopa za1.s, p0/M, p0/M, z14.s, z30.s\n" + ".inst 0x809601e2 // fmopa za2.s, p0/M, p0/M, z15.s, z22.s\n" + ".inst 0x809e01e3 // fmopa za3.s, p0/M, p0/M, z15.s, z30.s\n" + ".inst 0xa043476e // ld1w { z14.s-z15.s }, pn9.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + ".inst 0xa1434796 // ld1w { z22.s, z30.s }, pn9.b/Z, [x28, #0x6, MUL VL]\n" + "addvl x28, x28, #8\n" + "bgt 4b\n" + "5:" // K loop tail + ".inst 0x80940240 // fmopa za0.s, p0/M, p0/M, z18.s, z20.s\n" + ".inst 0x80950241 // fmopa za1.s, p0/M, p0/M, z18.s, z21.s\n" + ".inst 0x80940342 // fmopa za2.s, p0/M, p0/M, z26.s, z20.s\n" + ".inst 0x80950343 // fmopa za3.s, p0/M, p0/M, z26.s, z21.s\n" + ".inst 0x808a0080 // fmopa za0.s, p0/M, p0/M, z4.s, z10.s\n" + ".inst 0x808b0081 // fmopa za1.s, p0/M, p0/M, z4.s, z11.s\n" + ".inst 0x808a0182 // fmopa za2.s, p0/M, p0/M, z12.s, z10.s\n" + ".inst 0x808b0183 // fmopa za3.s, p0/M, p0/M, z12.s, z11.s\n" + ".inst 0x80980260 // fmopa za0.s, p0/M, p0/M, z19.s, z24.s\n" + ".inst 0x80990261 // fmopa za1.s, p0/M, p0/M, z19.s, z25.s\n" + ".inst 0x80980362 // fmopa za2.s, p0/M, p0/M, z27.s, z24.s\n" + ".inst 0x80990363 // fmopa za3.s, p0/M, p0/M, z27.s, z25.s\n" + ".inst 0x809601c0 // fmopa za0.s, p0/M, p0/M, z14.s, z22.s\n" + ".inst 0x809e01c1 // fmopa za1.s, p0/M, p0/M, z14.s, z30.s\n" + ".inst 0x809601e2 // fmopa za2.s, p0/M, p0/M, z15.s, z22.s\n" + ".inst 0x809e01e3 // fmopa za3.s, p0/M, p0/M, z15.s, z30.s\n" + "6:" // K oddments + "cbz x20, 8f\n" + "7:" // K oddments: Loop + ".inst 0xa040477c // ld1w { z28.s-z29.s }, pn9.b/Z, [x27]\n" + "subs x20, x20, #0x1\n" + "addvl x27, x27, 
#2\n" + ".inst 0xa1404787 // ld1w { z7.s, z15.s }, pn9.b/Z, [x28]\n" + "addvl x28, x28, #2\n" + ".inst 0x80870380 // fmopa za0.s, p0/M, p0/M, z28.s, z7.s\n" + ".inst 0x808f0381 // fmopa za1.s, p0/M, p0/M, z28.s, z15.s\n" + ".inst 0x808703a2 // fmopa za2.s, p0/M, p0/M, z29.s, z7.s\n" + ".inst 0x808f03a3 // fmopa za3.s, p0/M, p0/M, z29.s, z15.s\n" + "bgt 7b\n" + "8:" // K oddments: End + "ldr x26, [%x[args], %[offsetof_C]]\n" + "sub x25, x14, x13\n" + "cntw x24\n" + "ld1rw { z19.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" + "ldr x23, [%x[args], %[offsetof_ldcb]]\n" + "cmp x25, x24\n" + "ld1rw { z26.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" + "mov x12, #0x0\n" + "csel x22, x25, x24, LT\n" + "add x26, x26, x11, LSL #2\n" // C += n + "lsr x21, x22, #0x2\n" + "madd x26, x13, x23, x26\n" // C += m * ldc + "and x20, x22, #0x3\n" + "cbz x21, 11f\n" + "10:" // Store to output array: Accumulator row 0 loop + ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n" + ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n" + ".inst 0xc1baca64 // fclamp { z4.s-z7.s }, z19.s, z26.s\n" + ".inst 0xc1baca6c // fclamp { z12.s-z15.s }, z19.s, z26.s\n" + "add x12, x12, #0x4\n" + "cmp x12, x21, LSL #2\n" + ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n" + "add x26, x26, x23\n" + ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n" + "add x26, x26, x23\n" + ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n" + "add x26, x26, x23\n" + ".inst 0xa1604347 // st1w { z7.s, z15.s }, p8, [x26]\n" + "add x26, x26, x23\n" + "blt 10b\n" + "11:" // Store to output array: Accumulator row 0 oddments + "cbz x20, 12f\n" + ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n" + ".inst 0xc0860428 // mova { z8.s-z11.s }, za1h.s[x12]\n" + "subs x20, x20, #0x1\n" + ".inst 0xc1baca60 // fclamp { z0.s-z3.s }, z19.s, z26.s\n" + ".inst 0xc1baca68 // fclamp { z8.s-z11.s }, z19.s, z26.s\n" + ".inst 0xa1604340 // st1w { z0.s, z8.s }, p8, [x26]\n" + "add x26, x26, x23\n" + "beq 
12f\n" + "subs x20, x20, #0x1\n" + ".inst 0xa1604341 // st1w { z1.s, z9.s }, p8, [x26]\n" + "add x26, x26, x23\n" + "beq 12f\n" + ".inst 0xa1604342 // st1w { z2.s, z10.s }, p8, [x26]\n" + "add x26, x26, x23\n" + "12:" // Store to output array: Accumulator row 0 oddments: End + "subs x25, x25, x22\n" + "beq 16f\n" + "cmp x25, x24\n" + "mov x12, #0x0\n" + "csel x20, x25, x24, LT\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 14f\n" + "13:" // Store to output array: Accumulator row 1 loop + ".inst 0xc0860454 // mova { z20.s-z23.s }, za2h.s[x12]\n" + ".inst 0xc086047c // mova { z28.s-z31.s }, za3h.s[x12]\n" + ".inst 0xc1baca74 // fclamp { z20.s-z23.s }, z19.s, z26.s\n" + ".inst 0xc1baca7c // fclamp { z28.s-z31.s }, z19.s, z26.s\n" + "add x12, x12, #0x4\n" + "cmp x12, x21, LSL #2\n" + ".inst 0xa1604354 // st1w { z20.s, z28.s }, p8, [x26]\n" + "add x26, x26, x23\n" + ".inst 0xa1604355 // st1w { z21.s, z29.s }, p8, [x26]\n" + "add x26, x26, x23\n" + ".inst 0xa1604356 // st1w { z22.s, z30.s }, p8, [x26]\n" + "add x26, x26, x23\n" + ".inst 0xa1604357 // st1w { z23.s, z31.s }, p8, [x26]\n" + "add x26, x26, x23\n" + "blt 13b\n" + "14:" // Store to output array: Accumulator row 1 oddments + "cbz x20, 15f\n" + ".inst 0xc0860444 // mova { z4.s-z7.s }, za2h.s[x12]\n" + ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n" + "subs x20, x20, #0x1\n" + ".inst 0xc1baca64 // fclamp { z4.s-z7.s }, z19.s, z26.s\n" + ".inst 0xc1baca6c // fclamp { z12.s-z15.s }, z19.s, z26.s\n" + ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n" + "add x26, x26, x23\n" + "beq 15f\n" + "subs x20, x20, #0x1\n" + ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n" + "add x26, x26, x23\n" + "beq 15f\n" + ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n" + "15:" // Store to output array: Accumulator row 1 oddments: End + "16:" // Store to output array: End + "incw x11, ALL, MUL #2\n" + "cmp x11, x10\n" + "blt 2b\n" + "incw x13, ALL, MUL #2\n" + "mov x11, #0x0\n" + "cmp 
x13, x14\n" + "mov x9, x27\n" + "blt 1b\n" + ".inst 0xd503467f // SMSTOP\n" + : + : [args] "r"(&args), [offsetof_A] "I"(offsetof(KernelArgs, A)), [offsetof_B] "I"(offsetof(KernelArgs, B)), + [offsetof_C] "I"(offsetof(KernelArgs, C)), [offsetof_K] "I"(offsetof(KernelArgs, K)), + [offsetof_KernelArgs_max] "I"(offsetof(KernelArgs, max)), + [offsetof_KernelArgs_min] "I"(offsetof(KernelArgs, min)), [offsetof_M] "I"(offsetof(KernelArgs, M)), + [offsetof_N] "I"(offsetof(KernelArgs, N)), [offsetof_ldcb] "I"(offsetof(KernelArgs, ldcb)) + : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", + "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", + "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", + "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", + "z9"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h new file mode 100644 index 0000000000000000000000000000000000000000..c7ac5fa1fb8cc5537e05616d1f4bb9fd674b73fd --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h @@ -0,0 +1,97 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme to pack the LHS matrix. +/// -# kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme to pack the RHS matrix. + +/// Gets m step value. 
+/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return The m step value. +size_t kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(void); + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(void); + +/// Gets the offset in bytes to the data element in the packed LHS matrix buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. Must be a multiple of `m_step`. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the offset in bytes to the data element in the packed RHS matrix buffer. +/// +/// @param[in] n_idx Column index in the unpacked RHS matrix. Must be a multiple of `n_step`. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the offset in bytes to the data element in the destination matrix buffer. +/// +/// @param[in] m_idx Row index. Must be a multiple of `m_step`. +/// @param[in] n_idx Column index. Must be a multiple of `n_step`. +/// @param[in] dst_row_stride Distance in bytes between the starts of two consecutive rows in the output buffer. +/// +/// @return The offset in bytes to the data element. 
+size_t kai_get_dst_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_row_stride); + +/// Gets the size in bytes of the destination matrix buffer. +/// +/// @param[in] m Number of rows. +/// @param[in] n Number of columns. +/// +/// @return The size in bytes of the destination matrix buffer. +size_t kai_get_dst_size_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(size_t m, size_t n); + +/// Runs the matrix multiplication microkernel followed by a clamp operation. +/// +/// The pointer to each buffer (packed LHS, packed RHS and output) must be advanced by the offset +/// calculated using the following functions: +/// +/// * Packed LHS: @ref kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa. +/// * Output: @ref kai_get_dst_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa. +/// +/// @param[in] m Number of output rows to be computed. +/// @param[in] n Number of output columns to be computed. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// @param[in] lhs_packed Packed LHS matrix buffer. +/// @param[in] rhs_packed Packed RHS matrix buffer. +/// @param[out] dst Output matrix buffer. +/// @param[in] dst_row_stride Row stride in bytes of the output matrix. +/// @param[in] clamp_min Minimum value to clamp the final result. +/// @param[in] clamp_max Maximum value to clamp the final result. 
+void kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, + void* dst, size_t dst_row_stride, float clamp_min, float clamp_max); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..6e629274e7da42d24f484f5e1c5f70dd2ba06cda --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h @@ -0,0 +1,45 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif + +// All micro-kernel variants of the same type share the same interface. +// In this case, the micro-kernel type is: imatmul_clamp_f32_f32p_f32p + +/// Micro-kernel helper functions ("get" methods) +typedef size_t (*kai_imatmul_clamp_f32_f32p_f32p_get_m_step_func_t)(void); +typedef size_t (*kai_imatmul_clamp_f32_f32p_f32p_get_n_step_func_t)(void); +typedef size_t (*kai_imatmul_clamp_f32_f32p_f32p_get_lhs_packed_offset_func_t)( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length); +typedef size_t (*kai_imatmul_clamp_f32_f32p_f32p_get_rhs_packed_offset_func_t)( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length); +typedef size_t (*kai_imatmul_clamp_f32_f32p_f32p_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride); +typedef size_t (*kai_imatmul_clamp_f32_f32p_f32p_get_dst_size_func_t)(size_t m, size_t n); + +/// Micro-kernel core function ("run" method) +typedef void (*kai_imatmul_clamp_f32_f32p_f32p_run_imatmul_func_t)( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const 
void* lhs_packed, const void* rhs_packed, + void* dst, size_t dst_row_stride, float clamp_min, float clamp_max); + +/// Micro-kernel interface: table of function pointers shared by all imatmul_clamp_f32_f32p_f32p variants +struct kai_imatmul_clamp_f32_f32p_f32p_ukernel { + kai_imatmul_clamp_f32_f32p_f32p_get_m_step_func_t get_m_step; + kai_imatmul_clamp_f32_f32p_f32p_get_n_step_func_t get_n_step; + kai_imatmul_clamp_f32_f32p_f32p_get_lhs_packed_offset_func_t get_lhs_packed_offset; + kai_imatmul_clamp_f32_f32p_f32p_get_rhs_packed_offset_func_t get_rhs_packed_offset; + kai_imatmul_clamp_f32_f32p_f32p_get_dst_offset_func_t get_dst_offset; + kai_imatmul_clamp_f32_f32p_f32p_get_dst_size_func_t get_dst_size; + kai_imatmul_clamp_f32_f32p_f32p_run_imatmul_func_t run_imatmul; +}; + +#ifdef __cplusplus +} +#endif diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h index 21c7b52681a113a73a57e8350d212162e1fbd9e1..2f52001b7a2f7ff25392f3f85f37c18ba636511a 100644 --- a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h @@ -84,8 +84,8 @@ size_t kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme /// @param[in] n Number of output columns to be computed. /// @param[in] k_chunk_count Number of LHS column splits. /// @param[in] k_chunk_length Length of a LHS column split -/// @param[in] packed_lhs Packed LHS matrix buffer. -/// @param[in] packed_rhs Packed RHS matrix buffer. +/// @param[in] lhs_packed Packed LHS matrix buffer. +/// @param[in] rhs_packed Packed RHS matrix buffer. /// @param[out] dst Output matrix buffer. /// @param[in] dst_row_stride Row stride in bytes of the output matrix. 
diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c new file mode 100644 index 0000000000000000000000000000000000000000..bef728fe02e8265aef20b94d99c72227dd19e4ea --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.c @@ -0,0 +1,328 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Do not flag up inline assembly blocks +#pragma GCC diagnostic ignored "-Woverlength-strings" + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include "kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h" + +#include +#include + +#include "kai/kai_common.h" + +#define MR 2 +#define KR 1 +#define MAX_M_STEP (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(float)) / KR) + +static size_t kai_get_mr_lhs_imatmul_pack_x32p2vlx1_x32p_sme(void) { + return MR * kai_get_sme_vector_length_u32() / KR; +} + +size_t kai_get_m_step_lhs_imatmul_pack_x32p2vlx1_x32p_sme(void) { + return kai_get_mr_lhs_imatmul_pack_x32p2vlx1_x32p_sme(); +} + +size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x32p2vlx1_x32p_sme( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(m_idx % kai_get_m_step_lhs_imatmul_pack_x32p2vlx1_x32p_sme() == 0); + + return m_idx * k_chunk_count * kai_roundup(k_chunk_length, KR) * sizeof(float); +} + +size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme( + size_t m, size_t k_chunk_count, size_t k_chunk_length) { + const size_t m_end = kai_roundup(m, kai_get_mr_lhs_imatmul_pack_x32p2vlx1_x32p_sme()); + return kai_get_lhs_packed_offset_lhs_imatmul_pack_x32p2vlx1_x32p_sme(m_end, k_chunk_count, k_chunk_length); +} + +void kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme( + size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* 
lhs_ptrs, size_t lhs_ptr_offset, + const void* pad_ptr, void* lhs_packed) { + KAI_ASSUME(lhs_ptrs != NULL); + KAI_ASSUME(lhs_packed != NULL); + + const size_t m_step = kai_get_mr_lhs_imatmul_pack_x32p2vlx1_x32p_sme(); + const size_t row_offset = 0; + const size_t width = k_chunk_length; + + KAI_ASSERT(m_step <= MAX_M_STEP); + const uint8_t* in[MAX_M_STEP]; + + uint8_t* out_base = lhs_packed; + for (size_t i_m = 0; i_m < m; i_m += m_step) { + for (size_t i_k_chunk = 0; i_k_chunk < k_chunk_count; i_k_chunk += 1) { + const size_t height = KAI_MIN(m - i_m, m_step); + void* out = out_base; + for (size_t y = 0; y < height; y += 1) { + KAI_ASSERT(i_k_chunk + (i_m + y) * k_chunk_count < m * k_chunk_count); + in[y] = *(lhs_ptrs + i_m * k_chunk_count + i_k_chunk * m_step + y); + if (in[y] != pad_ptr) { + in[y] += lhs_ptr_offset; + } + } + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x21, %x[width]\n" + "mov x20, %x[width]\n" + "incw x21\n" + "cntw x17\n" + "sub x21, x21, #0x1\n" + "sub x16, x17, #0x1\n" + "udiv x21, x21, x17\n" // n_passes = ceildiv(width, VL) + "ands x16, x20, x16\n" + "sub x20, x21, #0x1\n" + "sub x15, x17, #0x2\n" + "mov x14, #0x0\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x17, LSL #3\n" + "cntw x9, ALL, MUL #2\n" + "ldr x28, [x11, #0x0]\n" + "cntw x27, ALL, MUL #3\n" + "lsr x20, x20, #0x1\n" // n_loops = (n_passes - 1) / 2 + "ldr x26, [x10, #0x0]\n" + "and x25, x21, #0x1\n" // odd_tail = bool(n_passes & 0x1) + "csel x16, x16, x17, NE\n" + "ldr x24, [x11, #0x8]\n" + "ptrue p12.s\n" + "whilelt p11.s, XZR, %x[height]\n" + "ldr x21, [x10, #0x8]\n" + "whilelt p10.s, x17, %x[height]\n" + "mov x23, %x[row_offset]\n" + "mov x22, %x[out]\n" + "whilelt p9.s, x14, %x[width]\n" + "whilelt p8.s, x14, %x[width]\n" + "add x11, x11, #0x10\n" + "add x10, x10, #0x10\n" + "mov x12, #0x0\n" + "cbz x15, 2f\n" + "1:" // K loop: Charge: Loop + ".inst 0x25306163 // psel p3.s, p8.s/Z, p11.s[w12]\n" + ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" 
+ ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" + ".inst 0x25706140 // psel p0.s, p8.s/Z, p10.s[w12, #1]\n" + ".inst 0xe0970f80 // ld1w { za0h.s[x12] }, p3/Z, [x28, x23, LSL #2]\n" + "ldr x28, [x11, #0x0]\n" + ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" + "ldr x26, [x10, #0x0]\n" + ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe09702a5 // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x23, LSL #2]\n" + "add x12, x12, #0x2\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "cmp x12, x15\n" + "blt 1b\n" + "2:" // K loop: Charge: End + ".inst 0x25306163 // psel p3.s, p8.s/Z, p11.s[w12]\n" + ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" + ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" + ".inst 0x25706140 // psel p0.s, p8.s/Z, p10.s[w12, #1]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x17, LSL #3\n" + ".inst 0xe0970f80 // ld1w { za0h.s[x12] }, p3/Z, [x28, x23, LSL #2]\n" + "ldr x28, [x11, #0x0]\n" + "incw x14\n" + ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" + "ldr x26, [x10, #0x0]\n" + ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe09702a5 // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x23, LSL #2]\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "incw x23\n" + "cbz x20, 8f\n" + "mov x20, x20\n" + "3:" // K loop: Main loop + "whilelt p8.s, x14, %x[width]\n" + "mov x13, #0x0\n" + "cbz x15, 5f\n" + "4:" // K loop: Main loop: First: Loop + ".inst 0x25316160 // psel p0.s, p8.s/Z, p11.s[w13]\n" + ".inst 0x25316142 // psel p2.s, p8.s/Z, p10.s[w13]\n" + ".inst 0x25716161 // psel p1.s, p8.s/Z, p11.s[w13, #1]\n" + ".inst 0x25716143 // psel p3.s, p8.s/Z, p10.s[w13, #1]\n" + ".inst 0xe0972388 // ld1w { za2h.s[x13] }, p0/Z, [x28, x23, LSL #2]\n" + ".inst 0x25317120 // psel p0.s, p12.s/Z, p9.s[w13]\n" + "ldr x28, 
[x11, #0x0]\n" + ".inst 0xe0972b4c // ld1w { za3h.s[x13] }, p2/Z, [x26, x23, LSL #2]\n" + ".inst 0x25317122 // psel p2.s, p12.s/Z, p9.s[w13]\n" + "ldr x26, [x10, #0x0]\n" + ".inst 0xe0972709 // ld1w { za2h.s[x13, #1] }, p1/Z, [x24, x23, LSL #2]\n" + ".inst 0x25717121 // psel p1.s, p12.s/Z, p9.s[w13, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0972ead // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x23, LSL #2]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bfa2c0 // st1w { za0v.s[x13] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25717120 // psel p0.s, p12.s/Z, p9.s[w13, #1]\n" + ".inst 0xe0b1aac4 // st1w { za1v.s[x13] }, p2/Z, [x22, x17, LSL #2]\n" + "add x10, x10, #0x10\n" + ".inst 0xe0a9a6c1 // st1w { za0v.s[x13, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bba2c5 // st1w { za1v.s[x13, #1] }, p0/Z, [x22, x27, LSL #2]\n" + "add x13, x13, #0x2\n" + "addvl x22, x22, #4\n" + "cmp x13, x15\n" + "blt 4b\n" + "5:" // K loop: Main loop: First: Tail + ".inst 0x25316160 // psel p0.s, p8.s/Z, p11.s[w13]\n" + ".inst 0x25316142 // psel p2.s, p8.s/Z, p10.s[w13]\n" + ".inst 0x25716161 // psel p1.s, p8.s/Z, p11.s[w13, #1]\n" + ".inst 0x25716143 // psel p3.s, p8.s/Z, p10.s[w13, #1]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x17, LSL #3\n" + ".inst 0xe0972388 // ld1w { za2h.s[x13] }, p0/Z, [x28, x23, LSL #2]\n" + ".inst 0x25317120 // psel p0.s, p12.s/Z, p9.s[w13]\n" + "ldr x28, [x11, #0x0]\n" + "mov x12, #0x0\n" + ".inst 0xe0972b4c // ld1w { za3h.s[x13] }, p2/Z, [x26, x23, LSL #2]\n" + ".inst 0x25317122 // psel p2.s, p12.s/Z, p9.s[w13]\n" + "ldr x26, [x10, #0x0]\n" + ".inst 0xe0972709 // ld1w { za2h.s[x13, #1] }, p1/Z, [x24, x23, LSL #2]\n" + ".inst 0x25717121 // psel p1.s, p12.s/Z, p9.s[w13, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0972ead // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x23, LSL #2]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bfa2c0 // st1w { za0v.s[x13] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25717120 // psel p0.s, 
p12.s/Z, p9.s[w13, #1]\n" + ".inst 0xe0b1aac4 // st1w { za1v.s[x13] }, p2/Z, [x22, x17, LSL #2]\n" + "whilelt p9.s, x14, %x[width]\n" + "incw x14\n" + ".inst 0xe0a9a6c1 // st1w { za0v.s[x13, #1] }, p1/Z, [x22, x9, LSL #2]\n" + "add x10, x10, #0x10\n" + "incw x23\n" + ".inst 0xe0bba2c5 // st1w { za1v.s[x13, #1] }, p0/Z, [x22, x27, LSL #2]\n" + "addvl x22, x22, #4\n" + "whilelt p8.s, x14, %x[width]\n" + "cbz x15, 7f\n" + "6:" // K loop: Main loop: Second: Loop + ".inst 0x25306160 // psel p0.s, p8.s/Z, p11.s[w12]\n" + ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" + ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" + ".inst 0x25706143 // psel p3.s, p8.s/Z, p10.s[w12, #1]\n" + ".inst 0xe0970380 // ld1w { za0h.s[x12] }, p0/Z, [x28, x23, LSL #2]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + "ldr x28, [x11, #0x0]\n" + ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" + ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" + "ldr x26, [x10, #0x0]\n" + ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" + ".inst 0x25707121 // psel p1.s, p12.s/Z, p9.s[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0970ea5 // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x23, LSL #2]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25707120 // psel p0.s, p12.s/Z, p9.s[w12, #1]\n" + ".inst 0xe0b18acc // st1w { za3v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" + "add x10, x10, #0x10\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bb82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x27, LSL #2]\n" + "add x12, x12, #0x2\n" + "addvl x22, x22, #4\n" + "cmp x12, x15\n" + "blt 6b\n" + "7:" // K loop: Main loop: Second: Tail + ".inst 0x25306160 // psel p0.s, p8.s/Z, p11.s[w12]\n" + ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" + ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" + ".inst 
0x25706143 // psel p3.s, p8.s/Z, p10.s[w12, #1]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x17, LSL #3\n" + ".inst 0xe0970380 // ld1w { za0h.s[x12] }, p0/Z, [x28, x23, LSL #2]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + "ldr x28, [x11, #0x0]\n" + ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" + ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" + "ldr x26, [x10, #0x0]\n" + ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" + ".inst 0x25707121 // psel p1.s, p12.s/Z, p9.s[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0970ea5 // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x23, LSL #2]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25707120 // psel p0.s, p12.s/Z, p9.s[w12, #1]\n" + ".inst 0xe0b18acc // st1w { za3v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" + "whilelt p9.s, x14, %x[width]\n" + "subs x20, x20, #0x1\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" + "add x10, x10, #0x10\n" + "incw x14\n" + ".inst 0xe0bb82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x27, LSL #2]\n" + "addvl x22, x22, #4\n" + "incw x23\n" + "bgt 3b\n" + "8:" // K loop: Tails + "cbnz x25, 11f\n" + "mov x11, %x[in]\n" + "whilelt p8.s, x14, %x[width]\n" + "mov x12, #0x0\n" + "9:" // K loop: Tails: Even: First + ".inst 0x25307123 // psel p3.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25306161 // psel p1.s, p8.s/Z, p11.s[w12]\n" + ".inst 0x25306140 // psel p0.s, p8.s/Z, p10.s[w12]\n" + ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0b18ac4 // st1w { za1v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" + "addvl x22, x22, #2\n" + "ldr x21, [x11, #0x0]\n" + "ldr x20, [x11, x17, LSL #0x3]\n" + "add x11, x11, #0x8\n" + ".inst 0xe09706a8 // ld1w { za2h.s[x12] }, p1/Z, [x21, x23, LSL #2]\n" + ".inst 0xe097028c // ld1w { za3h.s[x12] }, p0/Z, [x20, 
x23, LSL #2]\n" + "add x12, x12, #0x1\n" + "cmp x12, x17\n" + "blt 9b\n" + "whilelt p9.s, x14, %x[width]\n" + "whilelt p8.s, x14, %x[width]\n" + "mov x12, #0x0\n" + "10:" // K loop: Tails: Even: Second + ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0b182cc // st1w { za3v.s[x12] }, p0/Z, [x22, x17, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x16\n" + "blt 10b\n" + "whilelt p8.s, x14, %x[width]\n" + "b 13f\n" + "11:" // K loop: Tails: Odd + "mov x12, #0x0\n" + "12:" // K loop: Tails: Odd: Loop + ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0b182c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x17, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x16\n" + "blt 12b\n" + "13:" // K loop: End + "mov %x[out], x22\n" + ".inst 0xd503467f // SMSTOP\n" + : [out] "+&r"(out) + : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", + "p7", "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", + "x24", "x25", "x26", "x27", "x28", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", + "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", + "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + out_base += m_step * kai_roundup(k_chunk_length, KR) * sizeof(float); + } + } +} + +#endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h new file mode 100644 index 0000000000000000000000000000000000000000..5f6c68a945a969a23f39f392444ac41702e1e897 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h @@ -0,0 +1,61 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return Step size for row index. +size_t kai_get_m_step_lhs_imatmul_pack_x32p2vlx1_x32p_sme(void); + +/// Gets the offset in bytes to the data element in the packed LHS buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x32p2vlx1_x32p_sme( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the size in bytes of the packed LHS buffer. +/// +/// @param[in] m Number of rows in the unpacked LHS matrix. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// +/// @return The size in bytes of the packed LHS buffer. +size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme( + size_t m, size_t k_chunk_count, size_t k_chunk_length); + +/// Packs the LHS matrix for use with indirect matrix multiplication. +/// +/// @param[in] m Number of rows of the unpacked LHS matrix. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split.
+/// @param[in] lhs_ptrs Pointer to an array of input pointers consisting of +/// `m * k_chunk_count` pointers. Must not be NULL. +/// @param[in] lhs_ptr_offset Offset to add to each pointer of the @ref lhs_ptrs +/// array, excluding entries equal to @ref pad_ptr. +/// @param[in] pad_ptr Pointer to the chunk used for padding. @ref lhs_ptr_offset is +/// not applied to this pointer when used in @ref lhs_ptrs. This can +/// be NULL if there is no padding used in @ref lhs_ptrs. +/// @param[out] lhs_packed Packed LHS matrix. Must not be NULL. +void kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme( + size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, + const void* pad_ptr, void* lhs_packed); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c new file mode 100644 index 0000000000000000000000000000000000000000..46c626f1260af5fe2769bc4069f15ca726c524bd --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c @@ -0,0 +1,182 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Do not flag up inline assembly blocks +#pragma GCC diagnostic ignored "-Woverlength-strings" + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check.
+#include "kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h" + +#include +#include + +#include "kai/kai_common.h" + +#define NR 2 +#define KR 1 +static const size_t kai_num_bytes_input = sizeof(uint32_t); +static const size_t kai_num_bytes_output = sizeof(uint32_t); +static const size_t kai_num_bytes_bias = sizeof(uint32_t); + +size_t kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(void) { + return NR * kai_get_sme_vector_length_u32() / KR; +} + +size_t kai_get_rhs_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(size_t n_idx) { + KAI_ASSUME(n_idx % kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme() == 0); + + return n_idx * kai_num_bytes_input; +} + +size_t kai_get_bias_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(size_t n_idx) { + return n_idx * kai_num_bytes_bias; +} + +static size_t kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( + size_t k_chunk_count, size_t k_chunk_length) { + return kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme() * + (kai_num_bytes_bias + k_chunk_count * kai_roundup(k_chunk_length, KR) * kai_num_bytes_output); +} + +size_t kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(n_idx % kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme() == 0); + + const size_t block_idx = n_idx / kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(); + return block_idx * + kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(k_chunk_count, k_chunk_length); +} + +size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length) { + const size_t n_rounded_up = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme()); + return kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( + n_rounded_up, k_chunk_count, k_chunk_length); +} + +void
kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + void* rhs_packed) { + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(bias != NULL); + KAI_ASSUME(rhs_packed != NULL); + + size_t height = k_chunk_length; + const size_t width = n; + const void* in = rhs; + void* out = rhs_packed; + const size_t in_stride = rhs_row_stride; + + size_t out_stride = + kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(k_chunk_count, k_chunk_length); + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x22, %x[out]\n" + "mov x21, %x[width]\n" + "ptrue p2.b\n" + "1:" // Bias: Full loop (two bias vectors written per out_stride step) + "mov x20, x21\n" + "decw x21, ALL, MUL #2\n" + "whilelt p1.s, XZR, x20\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "cmp x21, #0x0\n" + "ld1w { z17.s }, p1/Z, [%x[bias]]\n" + "ld1w { z16.s }, p0/Z, [%x[bias], #1, MUL VL]\n" + "incb %x[bias], ALL, MUL #2\n" + "st1w { z17.s }, p2, [x22]\n" + "st1w { z16.s }, p2, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 1b\n" + "incb %x[out], ALL, MUL #2\n" + "mov x28, %x[k_chunk_count]\n" + "2:" // Chunk Loop (x28 counts down from k_chunk_count; `in` keeps advancing across chunks) + "mov x27, %x[height]\n" + "cmp x27, #0x4\n" + "blt 6f\n" + "3:" // Main row loop: Head (four input rows per iteration) + "mov x26, %x[in]\n" + "mov x25, %x[out]\n" + "add x24, x26, %x[in_stride]\n" + "sub x27, x27, #0x4\n" + "add x23, x24, %x[in_stride]\n" + "mov x22, %x[width]\n" + "add x21, x23, %x[in_stride]\n" + "add %x[in], x21, %x[in_stride]\n" + "4:" // Main row loop: Column loop (two vectors of columns per iteration) + "mov x20, x22\n" + "decw x22, ALL, MUL #2\n" + "whilelt p1.s, XZR, x20\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "cmp x22, #0x0\n" + "ld1w { z23.s }, p1/Z, [x26]\n" + "ld1w { z22.s }, p0/Z, [x26, #1, MUL VL]\n" + "addvl x26, x26, #2\n" + "ld1w { z21.s }, p1/Z, [x24]\n" + "ld1w { z20.s }, p0/Z, [x24, #1, MUL VL]\n" + "addvl x24, x24, #2\n" + "ld1w { z19.s }, p1/Z, [x23]\n" + "ld1w { z18.s }, p0/Z, [x23, #1, MUL VL]\n" + "addvl x23, x23,
#2\n" + "ld1w { z17.s }, p1/Z, [x21]\n" + "ld1w { z16.s }, p0/Z, [x21, #1, MUL VL]\n" + "addvl x21, x21, #2\n" + "st1w { z23.s }, p2, [x25]\n" + "st1w { z22.s }, p2, [x25, #1, MUL VL]\n" + "st1w { z21.s }, p2, [x25, #2, MUL VL]\n" + "st1w { z20.s }, p2, [x25, #3, MUL VL]\n" + "st1w { z19.s }, p2, [x25, #4, MUL VL]\n" + "st1w { z18.s }, p2, [x25, #5, MUL VL]\n" + "st1w { z17.s }, p2, [x25, #6, MUL VL]\n" + "st1w { z16.s }, p2, [x25, #7, MUL VL]\n" + "add x25, x25, %x[out_stride]\n" + "bgt 4b\n" + "cmp x27, #0x4\n" + "addvl %x[out], %x[out], #8\n" + "bge 3b\n" + "cbz x27, 10f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head (one input row per iteration) + "mov x26, %x[in]\n" + "cntw x22, ALL, MUL #8\n" + "add %x[in], x26, %x[in_stride]\n" + "mov x25, %x[out]\n" + "sub x27, x27, #0x1\n" + "mov x21, %x[width]\n" + "8:" // Tail row loop: Column loop (two vectors of columns per iteration) + "mov x20, x21\n" + "decw x21, ALL, MUL #2\n" + "whilelt p1.s, XZR, x20\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "cmp x21, #0x0\n" + "ld1w { z17.s }, p1/Z, [x26]\n" + "ld1w { z16.s }, p0/Z, [x26, #1, MUL VL]\n" + "add x26, x26, x22\n" + "st1w { z17.s }, p2, [x25]\n" + "st1w { z16.s }, p2, [x25, #1, MUL VL]\n" + "add x25, x25, %x[out_stride]\n" + "bgt 8b\n" + "cmp x27, #0x1\n" + "addvl %x[out], %x[out], #2\n" + "bge 7b\n" + "10:" // Done with this chunk; loop back while chunks remain + "sub x28, x28, #0x1\n" + "cbnz x28, 2b\n" + ".inst 0xd503467f // SMSTOP\n" + : [bias] "+&r"(bias), [in] "+&r"(in), [out] "+&r"(out) + : [height] "r"(height), [in_stride] "r"(in_stride), [k_chunk_count] "r"(k_chunk_count), + [out_stride] "r"(out_stride), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", + "p8", "p9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z10", "z11", "z12", + "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", + "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); +} + +#endif // Architectural
features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h new file mode 100644 index 0000000000000000000000000000000000000000..ea16c9df140741431697d6579970596c2ed51b38 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h @@ -0,0 +1,78 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return Step size for column index. +size_t kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(void); + +/// Gets the offset in bytes to the data element in the RHS matrix buffer. +/// +/// @param[in] n_idx Column index. Must be divisible by `n_step`. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the bias buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_bias_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the packed RHS buffer. +/// +/// @param[in] n_idx Column index. Must be divisible by `n_step`. +/// @param[in] k_chunk_count Number of chunks. +/// @param[in] k_chunk_length Number of rows in each chunk. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the size in bytes of the packed RHS buffer. +/// +/// @param[in] n Number of columns. +/// @param[in] k_chunk_count Number of chunks.
+/// @param[in] k_chunk_length Number of rows in each chunk. +/// +/// @return The size in bytes of the packed RHS buffer. +size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length); + +/// Runs the RHS packing function for indirect matrix multiplication. +/// +/// The pointer to each buffer (RHS, bias and packed RHS) must be advanced by the offset +/// calculated using the following functions: +/// +/// * RHS: @ref kai_get_rhs_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme. +/// * Bias: @ref kai_get_bias_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme. +/// * Output: @ref kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme. +/// +/// @param[in] n Number of columns of the output matrix. +/// @param[in] k_chunk_count Number of chunks. +/// @param[in] k_chunk_length Number of rows in each chunk. +/// @param[in] rhs_row_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs RHS matrix data buffer. Must not be NULL. +/// @param[in] bias Bias matrix data buffer. Must not be NULL. +/// @param[out] rhs_packed Packed RHS matrix. Must not be NULL.
+void kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + void* rhs_packed); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/test/tests/imatmul_test.cpp b/test/tests/imatmul_test.cpp index f787143bb8a1bd005e3c9cbbe4e5d9b923558595..dd12fd4fd6855a79015259c09c9f9c89e1172191 100644 --- a/test/tests/imatmul_test.cpp +++ b/test/tests/imatmul_test.cpp @@ -14,8 +14,12 @@ #include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" #include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h" +#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h" +#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h" #include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h" +#include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h" #include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h" #include "test/common/compare.hpp" #include "test/common/cpu_info.hpp" #include "test/common/matmul_test_common.hpp" @@ -121,7 +125,7 @@ struct IndirectMatMul { using Buffer = std::vector; /// Convenience type for test list -using IndirectMatMulArray = std::array; +using IndirectMatMulArray = std::array; /// Test parameter bundle type using IndirectMatMulTestParams = std::tuple; @@ -142,6 +146,19 @@ const kai_imatmul_clamp_f16_f16p_f16p_ukernel& get_imatmul_clamp_f16_f16p2vlx2_f return ukernel; } +/// Fills in and returns the ukernel interface table for the F32 SME2 indirect matmul kernel +const kai_imatmul_clamp_f32_f32p_f32p_ukernel& get_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa() { + static kai_imatmul_clamp_f32_f32p_f32p_ukernel
ukernel; + ukernel.get_m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; + ukernel.get_n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; + ukernel.get_lhs_packed_offset = kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; + ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; + ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; + ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; + ukernel.run_imatmul = kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; + return ukernel; +} + /// Retreive the test list const IndirectMatMulArray& get_indirect_matmul_methods() { static IndirectMatMulArray indirect_matmul_methods{}; @@ -185,6 +202,45 @@ const IndirectMatMulArray& get_indirect_matmul_methods() { indirect_matmul_methods[0].imatmul.get_dst_size = ukernel_f16.get_dst_size; indirect_matmul_methods[0].imatmul.imatmul = ukernel_f16.run_imatmul; + // F32 IMATMUL //////////////////////////////////////////////////////////// + indirect_matmul_methods[1].name = "indirect_matmul_f32_f32p_f32p_2vlx2vl_sme2_mopa"; + indirect_matmul_methods[1].is_supported = cpu_has_sme2; + indirect_matmul_methods[1].pack_shape.m = 2 * get_sme_vector_length(); + indirect_matmul_methods[1].pack_shape.n = 2 * get_sme_vector_length(); + indirect_matmul_methods[1].pack_shape.k = sizeof(int32_t); + indirect_matmul_methods[1].format.lhs = DataFormat(DataType::FP32); + indirect_matmul_methods[1].format.rhs = DataFormat(DataType::FP32); + indirect_matmul_methods[1].format.bias = DataFormat(DataType::FP32); + indirect_matmul_methods[1].format.out = DataFormat(DataType::FP32); + + // LHS packing kernel + indirect_matmul_methods[1].lhs.get_m_step = kai_get_m_step_lhs_imatmul_pack_x32p2vlx1_x32p_sme; + indirect_matmul_methods[1].lhs.get_lhs_packed_offset = +
kai_get_lhs_packed_offset_lhs_imatmul_pack_x32p2vlx1_x32p_sme; + indirect_matmul_methods[1].lhs.get_lhs_packed_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme; + indirect_matmul_methods[1].lhs.pack = kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme; + + // RHS packing kernel + indirect_matmul_methods[1].rhs.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; + indirect_matmul_methods[1].rhs.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; + indirect_matmul_methods[1].rhs.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; + indirect_matmul_methods[1].rhs.get_rhs_packed_offset = + kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; + indirect_matmul_methods[1].rhs.get_rhs_packed_size = + kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; + indirect_matmul_methods[1].rhs.pack = kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; + + // IMATMUL kernel, taken from the F32 ukernel interface table + const kai_imatmul_clamp_f32_f32p_f32p_ukernel& ukernel_f32 = + get_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + indirect_matmul_methods[1].imatmul.get_m_step = ukernel_f32.get_m_step; + indirect_matmul_methods[1].imatmul.get_n_step = ukernel_f32.get_n_step; + indirect_matmul_methods[1].imatmul.get_lhs_packed_offset = ukernel_f32.get_lhs_packed_offset; + indirect_matmul_methods[1].imatmul.get_rhs_packed_offset = ukernel_f32.get_rhs_packed_offset; + indirect_matmul_methods[1].imatmul.get_dst_offset = ukernel_f32.get_dst_offset; + indirect_matmul_methods[1].imatmul.get_dst_size = ukernel_f32.get_dst_size; + indirect_matmul_methods[1].imatmul.imatmul = ukernel_f32.run_imatmul; + return indirect_matmul_methods; }