diff --git a/CHANGELOG.md b/CHANGELOG.md index 127216919dd43400290cd33db7eaf0b8156b1ba3..1986fa1c978554e92fd690e700852ee94fee54fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,9 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - New SME micro-kernels: - Matrix multiplication (1xN) of F32 LHS and RHS with F32 output, using instructions compatible with FEAT_SME. - Matrix multiplication (1xN) of F16 LHS and RHS with F16 output, using instructions compatible with FEAT_SME. +- Convert SME transposed RHS packing micro-kernels to pure assembly: + - kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme + - kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme - Fixes - Update kai_kernel_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa to improve accuracy diff --git a/CMakeLists.txt b/CMakeLists.txt index 81068b62280ee89ebfa81f9f7e4202b9a6b5a040..054d5b4597d65d37a29ed62cc7263dc554409505 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -275,12 +275,14 @@ set(KLEIDIAI_FILES_SME_ASM kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme_asm.S kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S + kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c + kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme_asm.S + kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.c + kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme_asm.S ) set(KLEIDIAI_FILES_SME ${KLEIDIAI_FILES_SME_ASM} - kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c - kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.c ) set(KLEIDIAI_FILES_SME2_ASM diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 7c0abaf233edd1c1c7b21da345e70cae82fc5f41..ee19a7061f3bb5806c519339f12985405d1c3f11 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ 
b/kai/ukernels/matmul/BUILD.bazel @@ -149,8 +149,6 @@ I8MM_KERNELS_ASM = [ SME_KERNELS = [ "pack/kai_lhs_pack_bf16p2vlx2_f32_sme", "pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme", - "pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme", - "pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme", ] # buildifier: keep sorted @@ -172,6 +170,8 @@ SME_KERNELS_ASM = [ "pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", "pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", "pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme", + "pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme", + "pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme", ] # buildifier: keep sorted diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c index b5fe9904812c5078bc799ec16bf56efdda568fbe..57592fe8d60d30825539742b4752ec2317a3dee3 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c @@ -4,10 +4,7 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. 
@@ -18,14 +15,20 @@ #include "kai/kai_common.h" -static const size_t kai_nr = 2; -static const size_t kai_kr = 1; -static const size_t kai_sr = 1; -static const size_t kai_num_bytes_data = 4; -static const size_t kai_num_bytes_bias = 4; +enum { + NR = 2, + KR = 1, + SR = 1, + NUM_BYTES_DATA = 4, + NUM_BYTES_BIAS = 4, + MAX_BLOCK_HEIGHT = (NR * (KAI_SME_VEC_LENGTH_MAX_BYTES / NUM_BYTES_DATA) / KR), +}; + +void kai_kernel_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme( + size_t height, size_t width, const void* in, void* out, const void* bias); static size_t get_block_height(void) { - const size_t block_height = kai_nr * kai_get_sme_vector_length_u32() / kai_kr; + const size_t block_height = NR * kai_get_sme_vector_length_u32() / KR; return block_height; } @@ -42,11 +45,11 @@ size_t kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx size_t kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx) { KAI_ASSUME(n_idx % get_block_height() == 0); - return n_idx * kai_num_bytes_bias; + return n_idx * NUM_BYTES_BIAS; } size_t kai_get_rhs_packed_stride_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t k) { - return kai_num_bytes_bias + kai_roundup(k, kai_kr) * kai_num_bytes_data; + return NUM_BYTES_BIAS + kai_roundup(k, KR) * NUM_BYTES_DATA; } size_t kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx, size_t k) { @@ -64,8 +67,8 @@ void kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme( const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { KAI_ASSUME(num_groups == 1); KAI_ASSUME(nr == get_block_height()); - KAI_ASSUME(kr == kai_kr); - KAI_ASSUME(sr == kai_sr); + KAI_ASSUME(kr == KR); + KAI_ASSUME(sr == SR); KAI_ASSUME(rhs != NULL); KAI_ASSUME(bias != NULL); KAI_ASSUME(scale == NULL); @@ -75,286 +78,24 @@ void kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme( const size_t block_height = get_block_height(); const size_t width = k; - const size_t row_offset = 0; - const 
void* in[block_height]; + const uint8_t* in[MAX_BLOCK_HEIGHT]; uint8_t* rhs_packed_ptr = rhs_packed; const uint8_t* rhs_ptr = rhs; const uint8_t* bias_ptr = bias; for (size_t block_y = 0; block_y < n; block_y += block_height) { const size_t height = KAI_MIN(n - block_y, block_height); - void* out = rhs_packed_ptr + block_y * (kai_num_bytes_bias + kai_roundup(k, kai_kr) * kai_num_bytes_data); + uint8_t* out = rhs_packed_ptr + block_y * (NUM_BYTES_BIAS + kai_roundup(k, KR) * NUM_BYTES_DATA); for (size_t y = 0; y < height; y++) { in[y] = rhs_ptr + (block_y + y) * rhs_stride; } - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "ptrue p1.b\n" - "cbz %x[bias], 1f\n" - "mov x20, %x[height]\n" - "whilelt p0.s, XZR, %x[height]\n" - "decw x20\n" - "ld1w { z16.s }, p0/Z, [%x[bias]]\n" - "whilelt p0.s, XZR, x20\n" - "st1w { z16.s }, p1, [%x[out]]\n" - "ld1w { z16.s }, p0/Z, [%x[bias], #1, MUL VL]\n" - "st1w { z16.s }, p1, [%x[out], #1, MUL VL]\n" - "addvl %x[out], %x[out], #2\n" - "1:" // Bias: Done - "mov x21, %x[width]\n" - "cntw x17\n" - "incw x21\n" - "mov x20, %x[width]\n" - "sub x21, x21, #0x1\n" - "sub x16, x17, #0x1\n" - "udiv x21, x21, x17\n" // n_passes = ceildiv(width, VL) - "ands x16, x20, x16\n" - "sub x20, x21, #0x1\n" - "sub x15, x17, #0x2\n" - "mov x14, #0x0\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x17, LSL #3\n" - "cntw x9, ALL, MUL #2\n" - "ldr x28, [x11, #0x0]\n" - "cntw x27, ALL, MUL #3\n" - "lsr x20, x20, #0x1\n" // n_loops = (n_passes - 1) / 2 - "ldr x26, [x10, #0x0]\n" - "and x25, x21, #0x1\n" // odd_tail = bool(n_passes & 0x1) - "csel x16, x16, x17, NE\n" - "ldr x24, [x11, #0x8]\n" - "ptrue p12.s\n" - "whilelt p11.s, XZR, %x[height]\n" - "ldr x21, [x10, #0x8]\n" - "whilelt p10.s, x17, %x[height]\n" - "mov x23, %x[row_offset]\n" - "mov x22, %x[out]\n" - "whilelt p9.s, x14, %x[width]\n" - "whilelt p8.s, x14, %x[width]\n" - "add x11, x11, #0x10\n" - "add x10, x10, #0x10\n" - "mov x12, #0x0\n" - "cbz x15, 3f\n" - "2:" // K loop: Charge: 
Loop - ".inst 0x25306163 // psel p3.s, p8.s/Z, p11.s[w12]\n" - ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" - ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" - ".inst 0x25706140 // psel p0.s, p8.s/Z, p10.s[w12, #1]\n" - ".inst 0xe0970f80 // ld1w { za0h.s[x12] }, p3/Z, [x28, x23, LSL #2]\n" - "ldr x28, [x11, #0x0]\n" - ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe09702a5 // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x23, LSL #2]\n" - "add x12, x12, #0x2\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "cmp x12, x15\n" - "blt 2b\n" - "3:" // K loop: Charge: End - ".inst 0x25306163 // psel p3.s, p8.s/Z, p11.s[w12]\n" - ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" - ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" - ".inst 0x25706140 // psel p0.s, p8.s/Z, p10.s[w12, #1]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x17, LSL #3\n" - ".inst 0xe0970f80 // ld1w { za0h.s[x12] }, p3/Z, [x28, x23, LSL #2]\n" - "ldr x28, [x11, #0x0]\n" - "incw x14\n" - ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe09702a5 // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "incw x23\n" - "cbz x20, 9f\n" - "mov x20, x20\n" - "4:" // K loop: Main loop - "whilelt p8.s, x14, %x[width]\n" - "mov x13, #0x0\n" - "cbz x15, 6f\n" - "5:" // K loop: Main loop: First: Loop - ".inst 0x25316160 // psel p0.s, p8.s/Z, p11.s[w13]\n" - ".inst 0x25316142 // psel p2.s, p8.s/Z, p10.s[w13]\n" - ".inst 0x25716161 // psel p1.s, p8.s/Z, p11.s[w13, #1]\n" - ".inst 0x25716143 // psel p3.s, p8.s/Z, p10.s[w13, #1]\n" - ".inst 0xe0972388 // 
ld1w { za2h.s[x13] }, p0/Z, [x28, x23, LSL #2]\n" - ".inst 0x25317120 // psel p0.s, p12.s/Z, p9.s[w13]\n" - "ldr x28, [x11, #0x0]\n" - ".inst 0xe0972b4c // ld1w { za3h.s[x13] }, p2/Z, [x26, x23, LSL #2]\n" - ".inst 0x25317122 // psel p2.s, p12.s/Z, p9.s[w13]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0972709 // ld1w { za2h.s[x13, #1] }, p1/Z, [x24, x23, LSL #2]\n" - ".inst 0x25717121 // psel p1.s, p12.s/Z, p9.s[w13, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0972ead // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bfa2c0 // st1w { za0v.s[x13] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25717120 // psel p0.s, p12.s/Z, p9.s[w13, #1]\n" - ".inst 0xe0b1aac4 // st1w { za1v.s[x13] }, p2/Z, [x22, x17, LSL #2]\n" - "add x10, x10, #0x10\n" - ".inst 0xe0a9a6c1 // st1w { za0v.s[x13, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bba2c5 // st1w { za1v.s[x13, #1] }, p0/Z, [x22, x27, LSL #2]\n" - "add x13, x13, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x13, x15\n" - "blt 5b\n" - "6:" // K loop: Main loop: First: Tail - ".inst 0x25316160 // psel p0.s, p8.s/Z, p11.s[w13]\n" - ".inst 0x25316142 // psel p2.s, p8.s/Z, p10.s[w13]\n" - ".inst 0x25716161 // psel p1.s, p8.s/Z, p11.s[w13, #1]\n" - ".inst 0x25716143 // psel p3.s, p8.s/Z, p10.s[w13, #1]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x17, LSL #3\n" - ".inst 0xe0972388 // ld1w { za2h.s[x13] }, p0/Z, [x28, x23, LSL #2]\n" - ".inst 0x25317120 // psel p0.s, p12.s/Z, p9.s[w13]\n" - "ldr x28, [x11, #0x0]\n" - "mov x12, #0x0\n" - ".inst 0xe0972b4c // ld1w { za3h.s[x13] }, p2/Z, [x26, x23, LSL #2]\n" - ".inst 0x25317122 // psel p2.s, p12.s/Z, p9.s[w13]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0972709 // ld1w { za2h.s[x13, #1] }, p1/Z, [x24, x23, LSL #2]\n" - ".inst 0x25717121 // psel p1.s, p12.s/Z, p9.s[w13, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0972ead // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, 
#0x8]\n" - ".inst 0xe0bfa2c0 // st1w { za0v.s[x13] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25717120 // psel p0.s, p12.s/Z, p9.s[w13, #1]\n" - ".inst 0xe0b1aac4 // st1w { za1v.s[x13] }, p2/Z, [x22, x17, LSL #2]\n" - "whilelt p9.s, x14, %x[width]\n" - "incw x14\n" - ".inst 0xe0a9a6c1 // st1w { za0v.s[x13, #1] }, p1/Z, [x22, x9, LSL #2]\n" - "add x10, x10, #0x10\n" - "incw x23\n" - ".inst 0xe0bba2c5 // st1w { za1v.s[x13, #1] }, p0/Z, [x22, x27, LSL #2]\n" - "addvl x22, x22, #4\n" - "whilelt p8.s, x14, %x[width]\n" - "cbz x15, 8f\n" - "7:" // K loop: Main loop: Second: Loop - ".inst 0x25306160 // psel p0.s, p8.s/Z, p11.s[w12]\n" - ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" - ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" - ".inst 0x25706143 // psel p3.s, p8.s/Z, p10.s[w12, #1]\n" - ".inst 0xe0970380 // ld1w { za0h.s[x12] }, p0/Z, [x28, x23, LSL #2]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - "ldr x28, [x11, #0x0]\n" - ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" - ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" - ".inst 0x25707121 // psel p1.s, p12.s/Z, p9.s[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0970ea5 // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25707120 // psel p0.s, p12.s/Z, p9.s[w12, #1]\n" - ".inst 0xe0b18acc // st1w { za3v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" - "add x10, x10, #0x10\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bb82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x27, LSL #2]\n" - "add x12, x12, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x12, x15\n" - "blt 7b\n" - "8:" // K loop: Main loop: Second: Tail - ".inst 0x25306160 // psel p0.s, p8.s/Z, p11.s[w12]\n" - ".inst 
0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" - ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" - ".inst 0x25706143 // psel p3.s, p8.s/Z, p10.s[w12, #1]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x17, LSL #3\n" - ".inst 0xe0970380 // ld1w { za0h.s[x12] }, p0/Z, [x28, x23, LSL #2]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - "ldr x28, [x11, #0x0]\n" - ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" - ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" - "ldr x26, [x10, #0x0]\n" - ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" - ".inst 0x25707121 // psel p1.s, p12.s/Z, p9.s[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0970ea5 // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x23, LSL #2]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25707120 // psel p0.s, p12.s/Z, p9.s[w12, #1]\n" - ".inst 0xe0b18acc // st1w { za3v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" - "whilelt p9.s, x14, %x[width]\n" - "subs x20, x20, #0x1\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - "add x10, x10, #0x10\n" - "incw x14\n" - ".inst 0xe0bb82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x27, LSL #2]\n" - "addvl x22, x22, #4\n" - "incw x23\n" - "bgt 4b\n" - "9:" // K loop: Tails - "cbnz x25, 12f\n" - "mov x11, %x[in]\n" - "whilelt p8.s, x14, %x[width]\n" - "mov x12, #0x0\n" - "10:" // K loop: Tails: Even: First - ".inst 0x25307123 // psel p3.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25306161 // psel p1.s, p8.s/Z, p11.s[w12]\n" - ".inst 0x25306140 // psel p0.s, p8.s/Z, p10.s[w12]\n" - ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0b18ac4 // st1w { za1v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" - "addvl x22, x22, #2\n" - "ldr x21, [x11, #0x0]\n" - "ldr x20, [x11, x17, LSL #0x3]\n" - "add x11, x11, #0x8\n" - ".inst 
0xe09706a8 // ld1w { za2h.s[x12] }, p1/Z, [x21, x23, LSL #2]\n" - ".inst 0xe097028c // ld1w { za3h.s[x12] }, p0/Z, [x20, x23, LSL #2]\n" - "add x12, x12, #0x1\n" - "cmp x12, x17\n" - "blt 10b\n" - "whilelt p9.s, x14, %x[width]\n" - "whilelt p8.s, x14, %x[width]\n" - "mov x12, #0x0\n" - "11:" // K loop: Tails: Even: Second - ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0b182cc // st1w { za3v.s[x12] }, p0/Z, [x22, x17, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x16\n" - "blt 11b\n" - "whilelt p8.s, x14, %x[width]\n" - "b 14f\n" - "12:" // K loop: Tails: Odd - "mov x12, #0x0\n" - "13:" // K loop: Tails: Odd: Loop - ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0b182c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x17, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x16\n" - "blt 13b\n" - "14:" // K loop: End - "mov %x[out], x22\n" - ".inst 0xd503467f // SMSTOP\n" - : [out] "+&r"(out) - : [bias] "r"(bias), [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", - "p14", "p15", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", - "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", - "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", - "z26", "z27", "z28", "z29", "z30", "z31"); + kai_kernel_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme( + height, width, in, out, bias); // NOLINT(bugprone-multi-level-implicit-pointer-conversion) - bias_ptr += height * 
kai_num_bytes_bias; + bias_ptr += height * NUM_BYTES_BIAS; bias = bias_ptr; } } diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h index 77e68f6987ccd0aee9841834fc9ed4628de00c93..74f90f10ee42c121044be5505033dc112868f714 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h @@ -34,11 +34,11 @@ size_t kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx /// @return The offset in bytes to the data element. size_t kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx); -/// Gets the row stride in bytes to the packed RHS matrix. +/// Gets the row stride in bytes of the packed RHS matrix. /// -/// @param[in] k Number of columns. +/// @param[in] k The number of columns in the transposed RHS matrix. /// -/// @return Row stride in bytes to the packed RHS matrix. +/// @return The row stride in bytes. size_t kai_get_rhs_packed_stride_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t k); /// Gets the offset in bytes to the data element in the packed RHS buffer. @@ -69,7 +69,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t /// @param[in] num_groups Number of groups. It must be 1. /// @param[in] n Number of columns of the output matrix. /// @param[in] k Common dimension between the LHS and RHS matrix. -/// @param[in] nr Block size in N dimension. It must be 2 * kai_get_sme_vector_length_u32(). +/// @param[in] nr Block size in N dimension. It must match kai_get_n_step_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(). /// @param[in] kr Block size in K dimension. It must be 1. /// @param[in] sr Number of kr splits. It must be 1. /// @param[in] rhs_stride Row stride in bytes of the RHS matrix. 
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme_asm.S b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..1f8a72277b01640181296beda0b6d693af3a6b1e --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme_asm.S @@ -0,0 +1,320 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme) +KAI_ASM_FUNCTION_LABEL(kai_kernel_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme) + stp x20, x21, [sp, -144]! 
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + KAI_ASM_INST(0xd503477f) // SMSTART (enables streaming mode and ZA) + mov x22, #0x0 + ptrue p1.b + cbz x4, label_1 + mov x20, #0x0 + whilelt p0.s, XZR, x0 + incw x20 + ld1w { z16.s }, p0/Z, [x4] + whilelt p0.s, x20, x0 + st1w { z16.s }, p1, [x3] + ld1w { z16.s }, p0/Z, [x4, #1, MUL VL] + st1w { z16.s }, p1, [x3, #1, MUL VL] + addvl x3, x3, #2 +KAI_ASM_LABEL(label_1) // Bias: Done + mov x21, x1 + cntw x16 + incw x21 + mov x20, x1 + sub x21, x21, #0x1 + sub x15, x16, #0x1 + udiv x21, x21, x16 // n_passes = ceildiv(width, VL) + ands x15, x20, x15 + sub x20, x21, #0x1 + sub x14, x16, #0x2 + mov x11, #0x0 + mov x10, x2 + add x9, x2, x16, LSL #3 + cntw x28, ALL, MUL #2 + ldr x27, [x10, #0x0] + cntw x26, ALL, MUL #3 + lsr x20, x20, #0x1 // n_loops = (n_passes - 1) / 2 + ldr x25, [x9, #0x0] + and x24, x21, #0x1 // odd_tail = bool(n_passes & 0x1) + csel x15, x15, x16, NE + ldr x23, [x10, #0x8] + ptrue p12.s + whilelt p11.s, XZR, x0 + ldr x21, [x9, #0x8] + whilelt p10.s, x16, x0 + mov x22, x22 + whilelt p9.s, x11, x1 + whilelt p8.s, x11, x1 + add x10, x10, #0x10 + add x9, x9, #0x10 + mov x12, #0x0 + cbz x14, label_3 +KAI_ASM_LABEL(label_2) // K loop: Charge: Loop + KAI_ASM_INST(0x25306163) // psel p3.s, p8.s/Z, p11.s[w12] + KAI_ASM_INST(0x25306142) // psel p2.s, p8.s/Z, p10.s[w12] + KAI_ASM_INST(0x25706161) // psel p1.s, p8.s/Z, p11.s[w12, #1] + KAI_ASM_INST(0x25706140) // psel p0.s, p8.s/Z, p10.s[w12, #1] + KAI_ASM_INST(0xe0960f60) // ld1w { za0h.s[x12] }, p3/Z, [x27, x22, LSL #2] + ldr x27, [x10, #0x0] + KAI_ASM_INST(0xe0960b24) // ld1w { za1h.s[x12] }, p2/Z, [x25, x22, LSL #2] + ldr x25, [x9, #0x0] + KAI_ASM_INST(0xe09606e1) // ld1w { za0h.s[x12, #1] }, p1/Z, [x23, x22, LSL #2] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe09602a5) // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x22, LSL #2] + add 
x12, x12, #0x2 + ldr x21, [x9, #0x8] + add x9, x9, #0x10 + cmp x12, x14 + blt label_2 +KAI_ASM_LABEL(label_3) // K loop: Charge: End + KAI_ASM_INST(0x25306163) // psel p3.s, p8.s/Z, p11.s[w12] + KAI_ASM_INST(0x25306142) // psel p2.s, p8.s/Z, p10.s[w12] + KAI_ASM_INST(0x25706161) // psel p1.s, p8.s/Z, p11.s[w12, #1] + KAI_ASM_INST(0x25706140) // psel p0.s, p8.s/Z, p10.s[w12, #1] + mov x10, x2 + add x9, x2, x16, LSL #3 + KAI_ASM_INST(0xe0960f60) // ld1w { za0h.s[x12] }, p3/Z, [x27, x22, LSL #2] + ldr x27, [x10, #0x0] + incw x11 + KAI_ASM_INST(0xe0960b24) // ld1w { za1h.s[x12] }, p2/Z, [x25, x22, LSL #2] + ldr x25, [x9, #0x0] + KAI_ASM_INST(0xe09606e1) // ld1w { za0h.s[x12, #1] }, p1/Z, [x23, x22, LSL #2] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe09602a5) // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x22, LSL #2] + ldr x21, [x9, #0x8] + add x9, x9, #0x10 + incw x22 + cbz x20, label_9 + mov x20, x20 +KAI_ASM_LABEL(label_4) // K loop: Main loop + whilelt p8.s, x11, x1 + mov x13, #0x0 + cbz x14, label_6 +KAI_ASM_LABEL(label_5) // K loop: Main loop: First: Loop + KAI_ASM_INST(0x25316160) // psel p0.s, p8.s/Z, p11.s[w13] + KAI_ASM_INST(0x25316142) // psel p2.s, p8.s/Z, p10.s[w13] + KAI_ASM_INST(0x25716161) // psel p1.s, p8.s/Z, p11.s[w13, #1] + KAI_ASM_INST(0x25716143) // psel p3.s, p8.s/Z, p10.s[w13, #1] + KAI_ASM_INST(0xe0962368) // ld1w { za2h.s[x13] }, p0/Z, [x27, x22, LSL #2] + KAI_ASM_INST(0x25317120) // psel p0.s, p12.s/Z, p9.s[w13] + ldr x27, [x10, #0x0] + KAI_ASM_INST(0xe0962b2c) // ld1w { za3h.s[x13] }, p2/Z, [x25, x22, LSL #2] + KAI_ASM_INST(0x25317122) // psel p2.s, p12.s/Z, p9.s[w13] + ldr x25, [x9, #0x0] + KAI_ASM_INST(0xe09626e9) // ld1w { za2h.s[x13, #1] }, p1/Z, [x23, x22, LSL #2] + KAI_ASM_INST(0x25717121) // psel p1.s, p12.s/Z, p9.s[w13, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe0962ead) // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x22, LSL #2] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0xe0bfa060) // st1w { za0v.s[x13] }, 
p0/Z, [x3, XZR, LSL #2] + KAI_ASM_INST(0x25717120) // psel p0.s, p12.s/Z, p9.s[w13, #1] + KAI_ASM_INST(0xe0b0a864) // st1w { za1v.s[x13] }, p2/Z, [x3, x16, LSL #2] + add x9, x9, #0x10 + KAI_ASM_INST(0xe0bca461) // st1w { za0v.s[x13, #1] }, p1/Z, [x3, x28, LSL #2] + KAI_ASM_INST(0xe0baa065) // st1w { za1v.s[x13, #1] }, p0/Z, [x3, x26, LSL #2] + add x13, x13, #0x2 + addvl x3, x3, #4 + cmp x13, x14 + blt label_5 +KAI_ASM_LABEL(label_6) // K loop: Main loop: First: Tail + KAI_ASM_INST(0x25316160) // psel p0.s, p8.s/Z, p11.s[w13] + KAI_ASM_INST(0x25316142) // psel p2.s, p8.s/Z, p10.s[w13] + KAI_ASM_INST(0x25716161) // psel p1.s, p8.s/Z, p11.s[w13, #1] + KAI_ASM_INST(0x25716143) // psel p3.s, p8.s/Z, p10.s[w13, #1] + mov x10, x2 + add x9, x2, x16, LSL #3 + KAI_ASM_INST(0xe0962368) // ld1w { za2h.s[x13] }, p0/Z, [x27, x22, LSL #2] + KAI_ASM_INST(0x25317120) // psel p0.s, p12.s/Z, p9.s[w13] + ldr x27, [x10, #0x0] + mov x12, #0x0 + KAI_ASM_INST(0xe0962b2c) // ld1w { za3h.s[x13] }, p2/Z, [x25, x22, LSL #2] + KAI_ASM_INST(0x25317122) // psel p2.s, p12.s/Z, p9.s[w13] + ldr x25, [x9, #0x0] + KAI_ASM_INST(0xe09626e9) // ld1w { za2h.s[x13, #1] }, p1/Z, [x23, x22, LSL #2] + KAI_ASM_INST(0x25717121) // psel p1.s, p12.s/Z, p9.s[w13, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe0962ead) // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x22, LSL #2] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0xe0bfa060) // st1w { za0v.s[x13] }, p0/Z, [x3, XZR, LSL #2] + KAI_ASM_INST(0x25717120) // psel p0.s, p12.s/Z, p9.s[w13, #1] + KAI_ASM_INST(0xe0b0a864) // st1w { za1v.s[x13] }, p2/Z, [x3, x16, LSL #2] + whilelt p9.s, x11, x1 + incw x11 + KAI_ASM_INST(0xe0bca461) // st1w { za0v.s[x13, #1] }, p1/Z, [x3, x28, LSL #2] + add x9, x9, #0x10 + incw x22 + KAI_ASM_INST(0xe0baa065) // st1w { za1v.s[x13, #1] }, p0/Z, [x3, x26, LSL #2] + addvl x3, x3, #4 + whilelt p8.s, x11, x1 + cbz x14, label_8 +KAI_ASM_LABEL(label_7) // K loop: Main loop: Second: Loop + KAI_ASM_INST(0x25306160) // psel p0.s, p8.s/Z, 
p11.s[w12] + KAI_ASM_INST(0x25306142) // psel p2.s, p8.s/Z, p10.s[w12] + KAI_ASM_INST(0x25706161) // psel p1.s, p8.s/Z, p11.s[w12, #1] + KAI_ASM_INST(0x25706143) // psel p3.s, p8.s/Z, p10.s[w12, #1] + KAI_ASM_INST(0xe0960360) // ld1w { za0h.s[x12] }, p0/Z, [x27, x22, LSL #2] + KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12] + ldr x27, [x10, #0x0] + KAI_ASM_INST(0xe0960b24) // ld1w { za1h.s[x12] }, p2/Z, [x25, x22, LSL #2] + KAI_ASM_INST(0x25307122) // psel p2.s, p12.s/Z, p9.s[w12] + ldr x25, [x9, #0x0] + KAI_ASM_INST(0xe09606e1) // ld1w { za0h.s[x12, #1] }, p1/Z, [x23, x22, LSL #2] + KAI_ASM_INST(0x25707121) // psel p1.s, p12.s/Z, p9.s[w12, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe0960ea5) // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x22, LSL #2] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0xe0bf8068) // st1w { za2v.s[x12] }, p0/Z, [x3, XZR, LSL #2] + KAI_ASM_INST(0x25707120) // psel p0.s, p12.s/Z, p9.s[w12, #1] + KAI_ASM_INST(0xe0b0886c) // st1w { za3v.s[x12] }, p2/Z, [x3, x16, LSL #2] + add x9, x9, #0x10 + KAI_ASM_INST(0xe0bc8469) // st1w { za2v.s[x12, #1] }, p1/Z, [x3, x28, LSL #2] + KAI_ASM_INST(0xe0ba806d) // st1w { za3v.s[x12, #1] }, p0/Z, [x3, x26, LSL #2] + add x12, x12, #0x2 + addvl x3, x3, #4 + cmp x12, x14 + blt label_7 +KAI_ASM_LABEL(label_8) // K loop: Main loop: Second: Tail + KAI_ASM_INST(0x25306160) // psel p0.s, p8.s/Z, p11.s[w12] + KAI_ASM_INST(0x25306142) // psel p2.s, p8.s/Z, p10.s[w12] + KAI_ASM_INST(0x25706161) // psel p1.s, p8.s/Z, p11.s[w12, #1] + KAI_ASM_INST(0x25706143) // psel p3.s, p8.s/Z, p10.s[w12, #1] + mov x10, x2 + add x9, x2, x16, LSL #3 + KAI_ASM_INST(0xe0960360) // ld1w { za0h.s[x12] }, p0/Z, [x27, x22, LSL #2] + KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12] + ldr x27, [x10, #0x0] + KAI_ASM_INST(0xe0960b24) // ld1w { za1h.s[x12] }, p2/Z, [x25, x22, LSL #2] + KAI_ASM_INST(0x25307122) // psel p2.s, p12.s/Z, p9.s[w12] + ldr x25, [x9, #0x0] + KAI_ASM_INST(0xe09606e1) // ld1w { za0h.s[x12, #1] }, 
p1/Z, [x23, x22, LSL #2] + KAI_ASM_INST(0x25707121) // psel p1.s, p12.s/Z, p9.s[w12, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe0960ea5) // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x22, LSL #2] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0xe0bf8068) // st1w { za2v.s[x12] }, p0/Z, [x3, XZR, LSL #2] + KAI_ASM_INST(0x25707120) // psel p0.s, p12.s/Z, p9.s[w12, #1] + KAI_ASM_INST(0xe0b0886c) // st1w { za3v.s[x12] }, p2/Z, [x3, x16, LSL #2] + whilelt p9.s, x11, x1 + subs x20, x20, #0x1 + KAI_ASM_INST(0xe0bc8469) // st1w { za2v.s[x12, #1] }, p1/Z, [x3, x28, LSL #2] + add x9, x9, #0x10 + incw x11 + KAI_ASM_INST(0xe0ba806d) // st1w { za3v.s[x12, #1] }, p0/Z, [x3, x26, LSL #2] + addvl x3, x3, #4 + incw x22 + bgt label_4 +KAI_ASM_LABEL(label_9) // K loop: Tails + cbnz x24, label_12 + mov x10, x2 + whilelt p8.s, x11, x1 + mov x12, #0x0 +KAI_ASM_LABEL(label_10) // K loop: Tails: Even: First + KAI_ASM_INST(0x25307123) // psel p3.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25307122) // psel p2.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25306161) // psel p1.s, p8.s/Z, p11.s[w12] + KAI_ASM_INST(0x25306140) // psel p0.s, p8.s/Z, p10.s[w12] + KAI_ASM_INST(0xe0bf8c60) // st1w { za0v.s[x12] }, p3/Z, [x3, XZR, LSL #2] + KAI_ASM_INST(0xe0b08864) // st1w { za1v.s[x12] }, p2/Z, [x3, x16, LSL #2] + addvl x3, x3, #2 + ldr x21, [x10, #0x0] + ldr x20, [x10, x16, LSL #0x3] + add x10, x10, #0x8 + KAI_ASM_INST(0xe09606a8) // ld1w { za2h.s[x12] }, p1/Z, [x21, x22, LSL #2] + KAI_ASM_INST(0xe096028c) // ld1w { za3h.s[x12] }, p0/Z, [x20, x22, LSL #2] + add x12, x12, #0x1 + cmp x12, x16 + blt label_10 + whilelt p9.s, x11, x1 + whilelt p8.s, x11, x1 + mov x12, #0x0 +KAI_ASM_LABEL(label_11) // K loop: Tails: Even: Second + KAI_ASM_INST(0x25307121) // psel p1.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0xe0bf8468) // st1w { za2v.s[x12] }, p1/Z, [x3, XZR, LSL #2] + KAI_ASM_INST(0xe0b0806c) // st1w { za3v.s[x12] }, p0/Z, [x3, x16, LSL #2] + add x12, 
x12, #0x1 + addvl x3, x3, #2 + cmp x12, x15 + blt label_11 + whilelt p8.s, x11, x1 + b label_14 +KAI_ASM_LABEL(label_12) // K loop: Tails: Odd + mov x12, #0x0 +KAI_ASM_LABEL(label_13) // K loop: Tails: Odd: Loop + KAI_ASM_INST(0x25307121) // psel p1.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0xe0bf8460) // st1w { za0v.s[x12] }, p1/Z, [x3, XZR, LSL #2] + KAI_ASM_INST(0xe0b08064) // st1w { za1v.s[x12] }, p0/Z, [x3, x16, LSL #2] + add x12, x12, #0x1 + addvl x3, x3, #2 + cmp x12, x15 + blt label_13 +KAI_ASM_LABEL(label_14) // K loop: End + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.c index d292185823789d084ec41733df292b8260bec413..2fc11ef94cc855c470890db6bce7500bea3c3d88 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.c @@ -4,10 +4,7 @@ // SPDX-License-Identifier: Apache-2.0 // -// Do not flag up inline assembly blocks -#pragma GCC diagnostic ignored "-Woverlength-strings" - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. 
@@ -18,55 +15,60 @@ #include "kai/kai_common.h" -static const size_t kai_nr = 2; -static const size_t kai_kr = 2; -static const size_t kai_sr = 1; -static const size_t kai_num_bytes_data = 2; -static const size_t kai_num_bytes_bias = 2; +enum { + NR = 2, + KR = 2, + SR = 1, + NUM_BYTES_DATA = 2, + NUM_BYTES_BIAS = 2, + MAX_BLOCK_HEIGHT = (NR * (KAI_SME_VEC_LENGTH_MAX_BYTES / NUM_BYTES_DATA) / KR), /* sizes the in[] row-pointer array */ +}; + +void kai_kernel_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme( + size_t height, size_t width, const void* in, void* out, const void* bias); /* defined in the _sme_asm.S file */ -static size_t kai_get_block_height_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(void) { - const size_t block_height = kai_nr * kai_get_sme_vector_length_u16() / kai_kr; +static size_t get_block_height(void) { + const size_t block_height = NR * kai_get_sme_vector_length_u16() / KR; return block_height; } size_t kai_get_n_step_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(void) { - return kai_get_block_height_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(); + return get_block_height(); } size_t kai_get_rhs_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(size_t n_idx, size_t rhs_stride) { - KAI_ASSUME(n_idx % kai_get_block_height_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme() == 0); + KAI_ASSUME(n_idx % get_block_height() == 0); return n_idx * rhs_stride; } size_t kai_get_bias_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(size_t n_idx) { - KAI_ASSUME(n_idx % kai_get_block_height_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme() == 0); + KAI_ASSUME(n_idx % get_block_height() == 0); - return n_idx * kai_num_bytes_bias; + return n_idx * NUM_BYTES_BIAS; } size_t kai_get_rhs_packed_stride_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(size_t k) { - return kai_num_bytes_bias + kai_roundup(k, kai_kr) * kai_num_bytes_data; + return NUM_BYTES_BIAS + kai_roundup(k, KR) * NUM_BYTES_DATA; } size_t kai_get_rhs_packed_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(size_t n_idx, size_t k) { - KAI_ASSUME(n_idx % kai_get_block_height_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme() == 0); + KAI_ASSUME(n_idx % get_block_height() == 0);
return n_idx * kai_get_rhs_packed_stride_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(k); } size_t kai_get_rhs_packed_size_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(size_t n, size_t k) { - return kai_roundup(n, kai_get_block_height_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme()) * - kai_get_rhs_packed_stride_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(k); + return kai_roundup(n, get_block_height()) * kai_get_rhs_packed_stride_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(k); } void kai_run_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme( size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { KAI_ASSUME(num_groups == 1); - KAI_ASSUME(nr == kai_get_block_height_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme()); - KAI_ASSUME(kr == kai_kr); - KAI_ASSUME(sr == kai_sr); + KAI_ASSUME(nr == get_block_height()); + KAI_ASSUME(kr == KR); + KAI_ASSUME(sr == SR); KAI_ASSUME(rhs != NULL); KAI_ASSUME(bias != NULL); KAI_ASSUME(scale == NULL); @@ -74,294 +76,27 @@ void kai_run_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme( KAI_ASSUME(extra_bytes == 0); KAI_ASSUME(params == NULL); - const size_t block_height = kai_get_block_height_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(); + const size_t block_height = get_block_height(); const size_t width = k; - const size_t row_offset = 0; - const uint8_t* in[block_height]; + const uint8_t* in[MAX_BLOCK_HEIGHT]; + uint8_t* rhs_packed_ptr = rhs_packed; + const uint8_t* rhs_ptr = rhs; + const uint8_t* bias_ptr = bias; for (size_t block_y = 0; block_y < n; block_y += block_height) { const size_t height = KAI_MIN(n - block_y, block_height); - uint8_t* out = - (uint8_t*)rhs_packed + block_y * (kai_num_bytes_bias + kai_roundup(k, kai_kr) * kai_num_bytes_data); + uint8_t* out = rhs_packed_ptr + block_y * (NUM_BYTES_BIAS + kai_roundup(k, KR) * NUM_BYTES_DATA); for (size_t y = 0; y < height; y++) { - in[y] = (const uint8_t*)rhs + (block_y + y) * rhs_stride; + in[y] = rhs_ptr + (block_y 
+ y) * rhs_stride; } - __asm__ __volatile__( - ".inst 0xd503477f // SMSTART ZA\n" - "ptrue p1.b\n" - "cbz %x[bias], 1f\n" - "whilelt p0.h, XZR, %x[height]\n" - "ld1h { z16.h }, p0/Z, [%x[bias]]\n" - "st1h { z16.h }, p1, [%x[out]]\n" - "addvl %x[out], %x[out], #1\n" - "1:" // Bias: Done - "cnth x21\n" - "mov x22, %x[width]\n" - "inch x22\n" - "mov x20, %x[width]\n" - "sub x7, x21, #0x1\n" - "sub x22, x22, #0x1\n" - "ands x7, x20, x7\n" - "cntw x8\n" - "udiv x22, x22, x21\n" // n_passes = ceildiv(width, VL) - "csel x7, x7, x21, NE\n" - "sub x13, x22, #0x1\n" - "add x7, x7, #0x1\n" - "sub x17, x8, #0x2\n" - "lsl x21, %x[height], #0x1\n" // height * 2 - "lsl x20, x8, #0x1\n" - "mov x16, #0x0\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - "cntw x9, ALL, MUL #2\n" - "cntw x28, ALL, MUL #3\n" - "ldr x27, [x11, #0x0]\n" - "lsr x13, x13, #0x1\n" // n_loops = (n_passes - 1) / 2 - "and x26, x22, #0x1\n" // odd_tail = bool(n_passes & 0x1) - "ldr x25, [x10, #0x0]\n" - "lsr x7, x7, #0x1\n" - "ptrue p12.s\n" - "ldr x24, [x11, #0x8]\n" - "whilelt p11.h, XZR, x21\n" - "whilelt p10.h, x20, x21\n" - "ldr x21, [x10, #0x8]\n" - "mov x23, %x[row_offset]\n" - "mov x22, %x[out]\n" - "whilelt p9.h, x16, %x[width]\n" - "whilelt p8.h, x16, %x[width]\n" - "add x11, x11, #0x10\n" - "add x10, x10, #0x10\n" - "mov x12, #0x0\n" - "cbz x17, 3f\n" - "2:" // K loop: Charge: Loop - ".inst 0x25286163 // psel p3.h, p8.h/Z, p11.h[w12]\n" - ".inst 0x25286142 // psel p2.h, p8.h/Z, p10.h[w12]\n" - ".inst 0x25686161 // psel p1.h, p8.h/Z, p11.h[w12, #2]\n" - ".inst 0x25686140 // psel p0.h, p8.h/Z, p10.h[w12, #2]\n" - ".inst 0xe0570f60 // ld1h { za0h.h[x12] }, p3/Z, [x27, x23, LSL #1]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0570b28 // ld1h { za1h.h[x12] }, p2/Z, [x25, x23, LSL #1]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0570702 // ld1h { za0h.h[x12, #2] }, p1/Z, [x24, x23, LSL #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe05702aa // ld1h { za1h.h[x12, #2] }, p0/Z, 
[x21, x23, LSL #1]\n" - "add x12, x12, #0x4\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "cmp x12, x17, LSL #1\n" - "blt 2b\n" - "3:" // K loop: Charge: End - ".inst 0x25286163 // psel p3.h, p8.h/Z, p11.h[w12]\n" - ".inst 0x25286142 // psel p2.h, p8.h/Z, p10.h[w12]\n" - ".inst 0x25686161 // psel p1.h, p8.h/Z, p11.h[w12, #2]\n" - ".inst 0x25686140 // psel p0.h, p8.h/Z, p10.h[w12, #2]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0570f60 // ld1h { za0h.h[x12] }, p3/Z, [x27, x23, LSL #1]\n" - "ldr x27, [x11, #0x0]\n" - "inch x16\n" - ".inst 0xe0570b28 // ld1h { za1h.h[x12] }, p2/Z, [x25, x23, LSL #1]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0570702 // ld1h { za0h.h[x12, #2] }, p1/Z, [x24, x23, LSL #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe05702aa // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, #0x8]\n" - "add x10, x10, #0x10\n" - "inch x23\n" - "cbz x13, 9f\n" - "mov x20, x13\n" - "4:" // K loop: Main loop - "whilelt p8.h, x16, %x[width]\n" - "mov x15, #0x0\n" - "mov x14, #0x0\n" - "cbz x17, 6f\n" - "5:" // K loop: Main loop: First: Loop - ".inst 0x253b6160 // psel p0.h, p8.h/Z, p11.h[w15, #1]\n" - ".inst 0x253b6142 // psel p2.h, p8.h/Z, p10.h[w15, #1]\n" - ".inst 0x257b6161 // psel p1.h, p8.h/Z, p11.h[w15, #3]\n" - ".inst 0x257b6143 // psel p3.h, p8.h/Z, p10.h[w15, #3]\n" - ".inst 0xe0576361 // ld1h { za0h.h[x15, #1] }, p0/Z, [x27, x23, LSL #1]\n" - ".inst 0x252a7120 // psel p0.h, p12.h/Z, p9.h[w14]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0576b29 // ld1h { za1h.h[x15, #1] }, p2/Z, [x25, x23, LSL #1]\n" - ".inst 0x252a7122 // psel p2.h, p12.h/Z, p9.h[w14]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0576703 // ld1h { za0h.h[x15, #3] }, p1/Z, [x24, x23, LSL #1]\n" - ".inst 0x253a7121 // psel p1.h, p12.h/Z, p9.h[w14, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0576eab // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, 
#0x8]\n" - ".inst 0xe0bfc2c0 // st1w { za0v.s[x14] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x253a7120 // psel p0.h, p12.h/Z, p9.h[w14, #1]\n" - ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" - "add x10, x10, #0x10\n" - "add x15, x15, #0x4\n" - ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "add x14, x14, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x14, x17\n" - "blt 5b\n" - "6:" // K loop: Main loop: First: Tail - ".inst 0x253b6160 // psel p0.h, p8.h/Z, p11.h[w15, #1]\n" - ".inst 0x253b6142 // psel p2.h, p8.h/Z, p10.h[w15, #1]\n" - ".inst 0x257b6161 // psel p1.h, p8.h/Z, p11.h[w15, #3]\n" - ".inst 0x257b6143 // psel p3.h, p8.h/Z, p10.h[w15, #3]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0576361 // ld1h { za0h.h[x15, #1] }, p0/Z, [x27, x23, LSL #1]\n" - ".inst 0x252a7120 // psel p0.h, p12.h/Z, p9.h[w14]\n" - "ldr x27, [x11, #0x0]\n" - "mov x13, #0x0\n" - ".inst 0xe0576b29 // ld1h { za1h.h[x15, #1] }, p2/Z, [x25, x23, LSL #1]\n" - ".inst 0x252a7122 // psel p2.h, p12.h/Z, p9.h[w14]\n" - "ldr x25, [x10, #0x0]\n" - "mov x12, #0x0\n" - ".inst 0xe0576703 // ld1h { za0h.h[x15, #3] }, p1/Z, [x24, x23, LSL #1]\n" - ".inst 0x253a7121 // psel p1.h, p12.h/Z, p9.h[w14, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0576eab // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bfc2c0 // st1w { za0v.s[x14] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x253a7120 // psel p0.h, p12.h/Z, p9.h[w14, #1]\n" - ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" - "whilelt p9.h, x16, %x[width]\n" - "inch x16\n" - ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" - "add x10, x10, #0x10\n" - "inch x23\n" - ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "addvl x22, x22, #4\n" - "whilelt p8.h, x16, %x[width]\n" 
- "cbz x17, 8f\n" - "7:" // K loop: Main loop: Second: Loop - ".inst 0x25296160 // psel p0.h, p8.h/Z, p11.h[w13]\n" - ".inst 0x25296142 // psel p2.h, p8.h/Z, p10.h[w13]\n" - ".inst 0x25696161 // psel p1.h, p8.h/Z, p11.h[w13, #2]\n" - ".inst 0x25696143 // psel p3.h, p8.h/Z, p10.h[w13, #2]\n" - ".inst 0xe0572360 // ld1h { za0h.h[x13] }, p0/Z, [x27, x23, LSL #1]\n" - ".inst 0x25287120 // psel p0.h, p12.h/Z, p9.h[w12]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0572b28 // ld1h { za1h.h[x13] }, p2/Z, [x25, x23, LSL #1]\n" - ".inst 0x25287122 // psel p2.h, p12.h/Z, p9.h[w12]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0572702 // ld1h { za0h.h[x13, #2] }, p1/Z, [x24, x23, LSL #1]\n" - ".inst 0x25387121 // psel p1.h, p12.h/Z, p9.h[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0572eaa // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25387120 // psel p0.h, p12.h/Z, p9.h[w12, #1]\n" - ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - "add x10, x10, #0x10\n" - "add x13, x13, #0x4\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "add x12, x12, #0x2\n" - "addvl x22, x22, #4\n" - "cmp x12, x17\n" - "blt 7b\n" - "8:" // K loop: Main loop: Second: Tail - ".inst 0x25296160 // psel p0.h, p8.h/Z, p11.h[w13]\n" - ".inst 0x25296142 // psel p2.h, p8.h/Z, p10.h[w13]\n" - ".inst 0x25696161 // psel p1.h, p8.h/Z, p11.h[w13, #2]\n" - ".inst 0x25696143 // psel p3.h, p8.h/Z, p10.h[w13, #2]\n" - "mov x11, %x[in]\n" - "add x10, %x[in], x8, LSL #3\n" - ".inst 0xe0572360 // ld1h { za0h.h[x13] }, p0/Z, [x27, x23, LSL #1]\n" - ".inst 0x25287120 // psel p0.h, p12.h/Z, p9.h[w12]\n" - "ldr x27, [x11, #0x0]\n" - ".inst 0xe0572b28 // ld1h { za1h.h[x13] }, p2/Z, [x25, x23, LSL #1]\n" - ".inst 0x25287122 // psel p2.h, p12.h/Z, 
p9.h[w12]\n" - "ldr x25, [x10, #0x0]\n" - ".inst 0xe0572702 // ld1h { za0h.h[x13, #2] }, p1/Z, [x24, x23, LSL #1]\n" - ".inst 0x25387121 // psel p1.h, p12.h/Z, p9.h[w12, #1]\n" - "ldr x24, [x11, #0x8]\n" - "add x11, x11, #0x10\n" - ".inst 0xe0572eaa // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x23, LSL #1]\n" - "ldr x21, [x10, #0x8]\n" - ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" - ".inst 0x25387120 // psel p0.h, p12.h/Z, p9.h[w12, #1]\n" - ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - "whilelt p9.h, x16, %x[width]\n" - "subs x20, x20, #0x1\n" - ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" - "add x10, x10, #0x10\n" - "inch x16\n" - ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" - "addvl x22, x22, #4\n" - "inch x23\n" - "bgt 4b\n" - "9:" // K loop: Tails - "cbnz x26, 12f\n" - "mov x11, %x[in]\n" - "whilelt p8.h, x16, %x[width]\n" - "mov x13, #0x0\n" - "mov x12, #0x0\n" - "10:" // K loop: Tails: Even: First - ".inst 0x25307123 // psel p3.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25396161 // psel p1.h, p8.h/Z, p11.h[w13, #1]\n" - ".inst 0x25396140 // psel p0.h, p8.h/Z, p10.h[w13, #1]\n" - ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a88ac4 // st1w { za1v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "ldr x21, [x11, #0x0]\n" - "cmp x12, x8\n" - "ldr x20, [x11, x8, LSL #0x3]\n" - "add x11, x11, #0x8\n" - ".inst 0xe05726a1 // ld1h { za0h.h[x13, #1] }, p1/Z, [x21, x23, LSL #1]\n" - ".inst 0xe0572289 // ld1h { za1h.h[x13, #1] }, p0/Z, [x20, x23, LSL #1]\n" - "add x13, x13, #0x2\n" - "blt 10b\n" - "whilelt p9.h, x16, %x[width]\n" - "whilelt p8.h, x16, %x[width]\n" - "mov x20, #0x0\n" - "mov x12, #0x0\n" - "11:" // K loop: Tails: Even: Second - ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, 
p9.s[w12]\n" - "add x20, x20, #0x2\n" - ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a882cc // st1w { za3v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x7\n" - "blt 11b\n" - "whilelt p8.h, x16, %x[width]\n" - "b 14f\n" - "12:" // K loop: Tails: Odd - "mov x12, #0x0\n" - "13:" // K loop: Tails: Odd: Loop - ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" - ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" - ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" - ".inst 0xe0a882c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" - "add x12, x12, #0x1\n" - "addvl x22, x22, #2\n" - "cmp x12, x7\n" - "blt 13b\n" - "14:" // K loop: End - "mov %x[out], x22\n" - ".inst 0xd503467f // SMSTOP\n" - : [out] "+&r"(out) - : [bias] "r"(bias), [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", "x24", - "x25", "x26", "x27", "x28", "x7", "x8", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", - "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", - "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + kai_kernel_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme( + height, width, in, out, bias); // NOLINT(bugprone-multi-level-implicit-pointer-conversion) - bias = (const uint8_t*)bias + height * kai_num_bytes_bias; + bias_ptr += height * NUM_BYTES_BIAS; + bias = bias_ptr; } } diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme_asm.S b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..499c3f374e276343605570944935b8ef0928a9c2 --- /dev/null +++ 
b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme_asm.S @@ -0,0 +1,328 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(rhs_pack_nxk_x16p2vlx2b_x16_x16_sme) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme) +KAI_ASM_FUNCTION_LABEL(kai_kernel_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme) + stp x20, x21, [sp, -144]! 
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + KAI_ASM_INST(0xd503477f) // SMSTART (this encoding enables both streaming mode and ZA) + mov x12, #0x0 // x12 = 0; copied into x22 below as the starting element offset for the ld1h loads + ptrue p1.b + cbz x4, label_1 // x4 = bias; a null bias skips the bias vector + whilelt p0.h, XZR, x0 // x0 = height + ld1h { z16.h }, p0/Z, [x4] + st1h { z16.h }, p1, [x3] // x3 = out + addvl x3, x3, #1 +KAI_ASM_LABEL(label_1) // Bias: Done + cnth x21 + mov x22, x1 // x1 = width + inch x22 + mov x20, x1 + sub x8, x21, #0x1 + sub x22, x22, #0x1 + ands x8, x20, x8 + cntw x17 + udiv x22, x22, x21 // n_passes = ceildiv(width, VL) + csel x8, x8, x21, NE + sub x13, x22, #0x1 + add x8, x8, #0x1 + sub x16, x17, #0x2 + lsl x21, x0, #0x1 // height * 2 + lsl x20, x17, #0x1 + mov x11, #0x0 + mov x10, x2 // x2 = in (array of row pointers, read with ldr below) + add x9, x2, x17, LSL #3 + cntw x28, ALL, MUL #2 + cntw x27, ALL, MUL #3 + ldr x26, [x10, #0x0] + lsr x13, x13, #0x1 // n_loops = (n_passes - 1) / 2 + and x25, x22, #0x1 // odd_tail = bool(n_passes & 0x1) + ldr x24, [x9, #0x0] + lsr x8, x8, #0x1 + ptrue p12.s + ldr x23, [x10, #0x8] + whilelt p11.h, XZR, x21 + whilelt p10.h, x20, x21 + ldr x21, [x9, #0x8] + mov x22, x12 + whilelt p9.h, x11, x1 + whilelt p8.h, x11, x1 + add x10, x10, #0x10 + add x9, x9, #0x10 + mov x12, #0x0 + cbz x16, label_3 +KAI_ASM_LABEL(label_2) // K loop: Charge: Loop + KAI_ASM_INST(0x25286163) // psel p3.h, p8.h/Z, p11.h[w12] + KAI_ASM_INST(0x25286142) // psel p2.h, p8.h/Z, p10.h[w12] + KAI_ASM_INST(0x25686161) // psel p1.h, p8.h/Z, p11.h[w12, #2] + KAI_ASM_INST(0x25686140) // psel p0.h, p8.h/Z, p10.h[w12, #2] + KAI_ASM_INST(0xe0560f40) // ld1h { za0h.h[x12] }, p3/Z, [x26, x22, LSL #1] + ldr x26, [x10, #0x0] + KAI_ASM_INST(0xe0560b08) // ld1h { za1h.h[x12] }, p2/Z, [x24, x22, LSL #1] + ldr x24, [x9, #0x0] + KAI_ASM_INST(0xe05606e2) // ld1h { za0h.h[x12, #2] }, p1/Z, [x23, x22, LSL #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe05602aa) // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x22, LSL #1] + add x12, x12, #0x4 + ldr x21,
[x9, #0x8] + add x9, x9, #0x10 + cmp x12, x16, LSL #1 + blt label_2 +KAI_ASM_LABEL(label_3) // K loop: Charge: End + KAI_ASM_INST(0x25286163) // psel p3.h, p8.h/Z, p11.h[w12] + KAI_ASM_INST(0x25286142) // psel p2.h, p8.h/Z, p10.h[w12] + KAI_ASM_INST(0x25686161) // psel p1.h, p8.h/Z, p11.h[w12, #2] + KAI_ASM_INST(0x25686140) // psel p0.h, p8.h/Z, p10.h[w12, #2] + mov x10, x2 + add x9, x2, x17, LSL #3 + KAI_ASM_INST(0xe0560f40) // ld1h { za0h.h[x12] }, p3/Z, [x26, x22, LSL #1] + ldr x26, [x10, #0x0] + inch x11 + KAI_ASM_INST(0xe0560b08) // ld1h { za1h.h[x12] }, p2/Z, [x24, x22, LSL #1] + ldr x24, [x9, #0x0] + KAI_ASM_INST(0xe05606e2) // ld1h { za0h.h[x12, #2] }, p1/Z, [x23, x22, LSL #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe05602aa) // ld1h { za1h.h[x12, #2] }, p0/Z, [x21, x22, LSL #1] + ldr x21, [x9, #0x8] + add x9, x9, #0x10 + inch x22 + cbz x13, label_9 + mov x20, x13 +KAI_ASM_LABEL(label_4) // K loop: Main loop + whilelt p8.h, x11, x1 + mov x15, #0x0 + mov x14, #0x0 + cbz x16, label_6 +KAI_ASM_LABEL(label_5) // K loop: Main loop: First: Loop + KAI_ASM_INST(0x253b6160) // psel p0.h, p8.h/Z, p11.h[w15, #1] + KAI_ASM_INST(0x253b6142) // psel p2.h, p8.h/Z, p10.h[w15, #1] + KAI_ASM_INST(0x257b6161) // psel p1.h, p8.h/Z, p11.h[w15, #3] + KAI_ASM_INST(0x257b6143) // psel p3.h, p8.h/Z, p10.h[w15, #3] + KAI_ASM_INST(0xe0566341) // ld1h { za0h.h[x15, #1] }, p0/Z, [x26, x22, LSL #1] + KAI_ASM_INST(0x252a7120) // psel p0.h, p12.h/Z, p9.h[w14] + ldr x26, [x10, #0x0] + KAI_ASM_INST(0xe0566b09) // ld1h { za1h.h[x15, #1] }, p2/Z, [x24, x22, LSL #1] + KAI_ASM_INST(0x252a7122) // psel p2.h, p12.h/Z, p9.h[w14] + ldr x24, [x9, #0x0] + KAI_ASM_INST(0xe05666e3) // ld1h { za0h.h[x15, #3] }, p1/Z, [x23, x22, LSL #1] + KAI_ASM_INST(0x253a7121) // psel p1.h, p12.h/Z, p9.h[w14, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe0566eab) // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x22, LSL #1] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0xe0bfc060) // st1w { 
za0v.s[x14] }, p0/Z, [x3, XZR, LSL #2] + KAI_ASM_INST(0x253a7120) // psel p0.h, p12.h/Z, p9.h[w14, #1] + KAI_ASM_INST(0xe0b1c864) // st1w { za1v.s[x14] }, p2/Z, [x3, x17, LSL #2] + add x9, x9, #0x10 + add x15, x15, #0x4 + KAI_ASM_INST(0xe0bcc461) // st1w { za0v.s[x14, #1] }, p1/Z, [x3, x28, LSL #2] + KAI_ASM_INST(0xe0bbc065) // st1w { za1v.s[x14, #1] }, p0/Z, [x3, x27, LSL #2] + add x14, x14, #0x2 + addvl x3, x3, #4 + cmp x14, x16 + blt label_5 +KAI_ASM_LABEL(label_6) // K loop: Main loop: First: Tail + KAI_ASM_INST(0x253b6160) // psel p0.h, p8.h/Z, p11.h[w15, #1] + KAI_ASM_INST(0x253b6142) // psel p2.h, p8.h/Z, p10.h[w15, #1] + KAI_ASM_INST(0x257b6161) // psel p1.h, p8.h/Z, p11.h[w15, #3] + KAI_ASM_INST(0x257b6143) // psel p3.h, p8.h/Z, p10.h[w15, #3] + mov x10, x2 + add x9, x2, x17, LSL #3 + KAI_ASM_INST(0xe0566341) // ld1h { za0h.h[x15, #1] }, p0/Z, [x26, x22, LSL #1] + KAI_ASM_INST(0x252a7120) // psel p0.h, p12.h/Z, p9.h[w14] + ldr x26, [x10, #0x0] + mov x13, #0x0 + KAI_ASM_INST(0xe0566b09) // ld1h { za1h.h[x15, #1] }, p2/Z, [x24, x22, LSL #1] + KAI_ASM_INST(0x252a7122) // psel p2.h, p12.h/Z, p9.h[w14] + ldr x24, [x9, #0x0] + mov x12, #0x0 + KAI_ASM_INST(0xe05666e3) // ld1h { za0h.h[x15, #3] }, p1/Z, [x23, x22, LSL #1] + KAI_ASM_INST(0x253a7121) // psel p1.h, p12.h/Z, p9.h[w14, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe0566eab) // ld1h { za1h.h[x15, #3] }, p3/Z, [x21, x22, LSL #1] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0xe0bfc060) // st1w { za0v.s[x14] }, p0/Z, [x3, XZR, LSL #2] + KAI_ASM_INST(0x253a7120) // psel p0.h, p12.h/Z, p9.h[w14, #1] + KAI_ASM_INST(0xe0b1c864) // st1w { za1v.s[x14] }, p2/Z, [x3, x17, LSL #2] + whilelt p9.h, x11, x1 + inch x11 + KAI_ASM_INST(0xe0bcc461) // st1w { za0v.s[x14, #1] }, p1/Z, [x3, x28, LSL #2] + add x9, x9, #0x10 + inch x22 + KAI_ASM_INST(0xe0bbc065) // st1w { za1v.s[x14, #1] }, p0/Z, [x3, x27, LSL #2] + addvl x3, x3, #4 + whilelt p8.h, x11, x1 + cbz x16, label_8 +KAI_ASM_LABEL(label_7) // K loop: Main 
loop: Second: Loop + KAI_ASM_INST(0x25296160) // psel p0.h, p8.h/Z, p11.h[w13] + KAI_ASM_INST(0x25296142) // psel p2.h, p8.h/Z, p10.h[w13] + KAI_ASM_INST(0x25696161) // psel p1.h, p8.h/Z, p11.h[w13, #2] + KAI_ASM_INST(0x25696143) // psel p3.h, p8.h/Z, p10.h[w13, #2] + KAI_ASM_INST(0xe0562340) // ld1h { za0h.h[x13] }, p0/Z, [x26, x22, LSL #1] + KAI_ASM_INST(0x25287120) // psel p0.h, p12.h/Z, p9.h[w12] + ldr x26, [x10, #0x0] + KAI_ASM_INST(0xe0562b08) // ld1h { za1h.h[x13] }, p2/Z, [x24, x22, LSL #1] + KAI_ASM_INST(0x25287122) // psel p2.h, p12.h/Z, p9.h[w12] + ldr x24, [x9, #0x0] + KAI_ASM_INST(0xe05626e2) // ld1h { za0h.h[x13, #2] }, p1/Z, [x23, x22, LSL #1] + KAI_ASM_INST(0x25387121) // psel p1.h, p12.h/Z, p9.h[w12, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe0562eaa) // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x22, LSL #1] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0xe0bf8068) // st1w { za2v.s[x12] }, p0/Z, [x3, XZR, LSL #2] + KAI_ASM_INST(0x25387120) // psel p0.h, p12.h/Z, p9.h[w12, #1] + KAI_ASM_INST(0xe0b1886c) // st1w { za3v.s[x12] }, p2/Z, [x3, x17, LSL #2] + add x9, x9, #0x10 + add x13, x13, #0x4 + KAI_ASM_INST(0xe0bc8469) // st1w { za2v.s[x12, #1] }, p1/Z, [x3, x28, LSL #2] + KAI_ASM_INST(0xe0bb806d) // st1w { za3v.s[x12, #1] }, p0/Z, [x3, x27, LSL #2] + add x12, x12, #0x2 + addvl x3, x3, #4 + cmp x12, x16 + blt label_7 +KAI_ASM_LABEL(label_8) // K loop: Main loop: Second: Tail + KAI_ASM_INST(0x25296160) // psel p0.h, p8.h/Z, p11.h[w13] + KAI_ASM_INST(0x25296142) // psel p2.h, p8.h/Z, p10.h[w13] + KAI_ASM_INST(0x25696161) // psel p1.h, p8.h/Z, p11.h[w13, #2] + KAI_ASM_INST(0x25696143) // psel p3.h, p8.h/Z, p10.h[w13, #2] + mov x10, x2 + add x9, x2, x17, LSL #3 + KAI_ASM_INST(0xe0562340) // ld1h { za0h.h[x13] }, p0/Z, [x26, x22, LSL #1] + KAI_ASM_INST(0x25287120) // psel p0.h, p12.h/Z, p9.h[w12] + ldr x26, [x10, #0x0] + KAI_ASM_INST(0xe0562b08) // ld1h { za1h.h[x13] }, p2/Z, [x24, x22, LSL #1] + KAI_ASM_INST(0x25287122) // psel p2.h, p12.h/Z, 
p9.h[w12] + ldr x24, [x9, #0x0] + KAI_ASM_INST(0xe05626e2) // ld1h { za0h.h[x13, #2] }, p1/Z, [x23, x22, LSL #1] + KAI_ASM_INST(0x25387121) // psel p1.h, p12.h/Z, p9.h[w12, #1] + ldr x23, [x10, #0x8] + add x10, x10, #0x10 + KAI_ASM_INST(0xe0562eaa) // ld1h { za1h.h[x13, #2] }, p3/Z, [x21, x22, LSL #1] + ldr x21, [x9, #0x8] + KAI_ASM_INST(0xe0bf8068) // st1w { za2v.s[x12] }, p0/Z, [x3, XZR, LSL #2] + KAI_ASM_INST(0x25387120) // psel p0.h, p12.h/Z, p9.h[w12, #1] + KAI_ASM_INST(0xe0b1886c) // st1w { za3v.s[x12] }, p2/Z, [x3, x17, LSL #2] + whilelt p9.h, x11, x1 + subs x20, x20, #0x1 + KAI_ASM_INST(0xe0bc8469) // st1w { za2v.s[x12, #1] }, p1/Z, [x3, x28, LSL #2] + add x9, x9, #0x10 + inch x11 + KAI_ASM_INST(0xe0bb806d) // st1w { za3v.s[x12, #1] }, p0/Z, [x3, x27, LSL #2] + addvl x3, x3, #4 + inch x22 + bgt label_4 +KAI_ASM_LABEL(label_9) // K loop: Tails + cbnz x25, label_12 + mov x10, x2 + whilelt p8.h, x11, x1 + mov x13, #0x0 + mov x12, #0x0 +KAI_ASM_LABEL(label_10) // K loop: Tails: Even: First + KAI_ASM_INST(0x25307123) // psel p3.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25307122) // psel p2.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25396161) // psel p1.h, p8.h/Z, p11.h[w13, #1] + KAI_ASM_INST(0x25396140) // psel p0.h, p8.h/Z, p10.h[w13, #1] + KAI_ASM_INST(0xe0bf8c60) // st1w { za0v.s[x12] }, p3/Z, [x3, XZR, LSL #2] + KAI_ASM_INST(0xe0b18864) // st1w { za1v.s[x12] }, p2/Z, [x3, x17, LSL #2] + add x12, x12, #0x1 + addvl x3, x3, #2 + ldr x21, [x10, #0x0] + cmp x12, x17 + ldr x20, [x10, x17, LSL #0x3] + add x10, x10, #0x8 + KAI_ASM_INST(0xe05626a1) // ld1h { za0h.h[x13, #1] }, p1/Z, [x21, x22, LSL #1] + KAI_ASM_INST(0xe0562289) // ld1h { za1h.h[x13, #1] }, p0/Z, [x20, x22, LSL #1] + add x13, x13, #0x2 + blt label_10 + whilelt p9.h, x11, x1 + whilelt p8.h, x11, x1 + mov x20, #0x0 + mov x12, #0x0 +KAI_ASM_LABEL(label_11) // K loop: Tails: Even: Second + KAI_ASM_INST(0x25307121) // psel p1.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12] + add 
x20, x20, #0x2 + KAI_ASM_INST(0xe0bf8468) // st1w { za2v.s[x12] }, p1/Z, [x3, XZR, LSL #2] + KAI_ASM_INST(0xe0b1806c) // st1w { za3v.s[x12] }, p0/Z, [x3, x17, LSL #2] + add x12, x12, #0x1 + addvl x3, x3, #2 + cmp x12, x8 + blt label_11 + whilelt p8.h, x11, x1 + b label_14 +KAI_ASM_LABEL(label_12) // K loop: Tails: Odd + mov x12, #0x0 +KAI_ASM_LABEL(label_13) // K loop: Tails: Odd: Loop + KAI_ASM_INST(0x25307121) // psel p1.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0x25307120) // psel p0.s, p12.s/Z, p9.s[w12] + KAI_ASM_INST(0xe0bf8460) // st1w { za0v.s[x12] }, p1/Z, [x3, XZR, LSL #2] + KAI_ASM_INST(0xe0b18064) // st1w { za1v.s[x12] }, p0/Z, [x3, x17, LSL #2] + add x12, x12, #0x1 + addvl x3, x3, #2 + cmp x12, x8 + blt label_13 +KAI_ASM_LABEL(label_14) // K loop: End + KAI_ASM_INST(0xd503467f) // SMSTOP (leaves streaming mode and disables ZA) + ldp x22, x23, [sp, 16] // restore the callee-saved general registers stored by the prologue + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] // restore callee-saved d8-d15, done after SMSTOP + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 // restore x20/x21 and release the 144-byte frame + ret + KAI_ASM_FUNCTION_END(kai_kernel_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme) + + KAI_ASM_END