diff --git a/CMakeLists.txt b/CMakeLists.txt index 649e158504fa3b9ba96bc4a6be6f7cdee0944aac..b6dc8410738a0cca08c4976ff979f5a25fee4621 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -122,6 +122,7 @@ set(KLEIDIAI_FILES_NEON_I8MM set(KLEIDIAI_FILES_SME kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c + kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32pb_f32_f32_16vlx1_sme.c ) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 66a3c3868a97e70828efb60ef3f84d724c0e8ef2..35267f4248ad81bb42d9ddd09642578ee6216051 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -211,6 +211,13 @@ kai_c_library( cpu_uarch = kai_cpu_sme(), ) +kai_c_library( + name = "rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme", + srcs = ["pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c"], + hdrs = ["pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h"], + cpu_uarch = kai_cpu_sme(), +) + kai_c_library( name = "rhs_pack_kxn_f32pb_f32_f32_16vlx1_sme", srcs = ["pack/kai_rhs_pack_kxn_f32pb_f32_f32_16vlx1_sme.c"], @@ -368,6 +375,7 @@ kai_c_library( ":rhs_pack_kxn_f32pbiasf32_f32_f32_neon", ":rhs_pack_kxn_qsi4c32p_qsu4c32s1s0", ":rhs_pack_kxn_qsi4cxp_qs4cxs1s0", + ":rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme", ":rhs_pack_nxk_qsi4c32p_qsu4c32s1s0", ":rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0", ":rhs_pack_nxk_qsi4cxp_qs4cxs1s0", diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c new file mode 100644 index 0000000000000000000000000000000000000000..026baf0f9654c40377ea156126ed42bb707534d6 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c @@ -0,0 +1,353 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// 
SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include <stddef.h> +#include <stdint.h> + +#include "kai/kai_common.h" + +static const size_t kai_nr = 2; +static const size_t kai_kr = 1; +static const size_t kai_sr = 1; +static const size_t kai_num_bytes_data = 4; +static const size_t kai_num_bytes_bias = 4; + +static size_t get_block_height(void) { + const size_t block_height = kai_nr * kai_get_sme_vector_length_u32() / kai_kr; + return block_height; +} + +size_t kai_get_n_step_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(void) { + return get_block_height(); +} + +size_t kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx, size_t rhs_stride) { + KAI_ASSUME(n_idx % get_block_height() == 0); + + return n_idx * rhs_stride; +} + +size_t kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx) { + KAI_ASSUME(n_idx % get_block_height() == 0); + + return n_idx * kai_num_bytes_bias; +} + +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t k) { + return kai_num_bytes_bias + kai_roundup(k, kai_kr) * kai_num_bytes_data; +} + +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % get_block_height() == 0); + + return n_idx * kai_get_rhs_packed_stride_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(k); +} + +size_t kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t n, size_t k) { + return kai_roundup(n, get_block_height()) * kai_get_rhs_packed_stride_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(k); +} + +void kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { + KAI_ASSUME(num_groups == 1); + 
KAI_ASSUME(nr == get_block_height()); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(bias != NULL); + KAI_ASSUME(scale == NULL); + KAI_ASSUME(rhs_packed != NULL); + KAI_ASSUME(extra_bytes == 0); + KAI_ASSUME(params == NULL); + + const size_t block_height = get_block_height(); + const size_t width = k; + const size_t row_offset = 0; + + const void* in[block_height]; + + for (size_t block_y = 0; block_y < n; block_y += block_height) { + const size_t height = KAI_MIN(n - block_y, block_height); + void* out = rhs_packed + block_y * (kai_num_bytes_bias + kai_roundup(k, kai_kr) * kai_num_bytes_data); + + for (size_t y = 0; y < height; y++) { + in[y] = rhs + (block_y + y) * rhs_stride; + } + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "ptrue p1.b\n" + "cbz %x[bias], 1f\n" + "mov x20, %x[height]\n" + "whilelt p0.s, XZR, %x[height]\n" + "decw x20\n" + "ld1w { z16.s }, p0/Z, [%x[bias]]\n" + "whilelt p0.s, XZR, x20\n" + "st1w { z16.s }, p1, [%x[out]]\n" + "ld1w { z16.s }, p0/Z, [%x[bias], #1, MUL VL]\n" + "st1w { z16.s }, p1, [%x[out], #1, MUL VL]\n" + "addvl %x[out], %x[out], #2\n" + "1:" // Bias: Done + "mov x21, %x[width]\n" + "cntw x17\n" + "incw x21\n" + "mov x20, %x[width]\n" + "sub x21, x21, #0x1\n" + "sub x16, x17, #0x1\n" + "udiv x21, x21, x17\n" // n_passes = ceildiv(width, VL) + "ands x16, x20, x16\n" + "sub x20, x21, #0x1\n" + "sub x15, x17, #0x2\n" + "mov x14, #0x0\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x17, LSL #3\n" + "cntw x9, ALL, MUL #2\n" + "ldr x28, [x11, #0x0]\n" + "cntw x27, ALL, MUL #3\n" + "lsr x20, x20, #0x1\n" // n_loops = (n_passes - 1) / 2 + "ldr x26, [x10, #0x0]\n" + "and x25, x21, #0x1\n" // odd_tail = bool(n_passes & 0x1) + "csel x16, x16, x17, NE\n" + "ldr x24, [x11, #0x8]\n" + "ptrue p12.s\n" + "whilelt p11.s, XZR, %x[height]\n" + "ldr x21, [x10, #0x8]\n" + "whilelt p10.s, x17, %x[height]\n" + "mov x23, %x[row_offset]\n" + "mov x22, %x[out]\n" + "whilelt p9.s, x14, 
%x[width]\n" + "whilelt p8.s, x14, %x[width]\n" + "add x11, x11, #0x10\n" + "add x10, x10, #0x10\n" + "mov x12, #0x0\n" + "cbz x15, 3f\n" + "2:" // K loop: Charge: Loop + ".inst 0x25306163 // psel p3.s, p8.s/Z, p11.s[w12]\n" + ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" + ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" + ".inst 0x25706140 // psel p0.s, p8.s/Z, p10.s[w12, #1]\n" + ".inst 0xe0970f80 // ld1w { za0h.s[x12] }, p3/Z, [x28, x23, LSL #2]\n" + "ldr x28, [x11, #0x0]\n" + ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" + "ldr x26, [x10, #0x0]\n" + ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe09702a5 // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x23, LSL #2]\n" + "add x12, x12, #0x2\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "cmp x12, x15\n" + "blt 2b\n" + "3:" // K loop: Charge: End + ".inst 0x25306163 // psel p3.s, p8.s/Z, p11.s[w12]\n" + ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" + ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" + ".inst 0x25706140 // psel p0.s, p8.s/Z, p10.s[w12, #1]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x17, LSL #3\n" + ".inst 0xe0970f80 // ld1w { za0h.s[x12] }, p3/Z, [x28, x23, LSL #2]\n" + "ldr x28, [x11, #0x0]\n" + "incw x14\n" + ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" + "ldr x26, [x10, #0x0]\n" + ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe09702a5 // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x23, LSL #2]\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "incw x23\n" + "cbz x20, 9f\n" + "mov x20, x20\n" + "4:" // K loop: Main loop + "whilelt p8.s, x14, %x[width]\n" + "mov x13, #0x0\n" + "cbz x15, 6f\n" + "5:" // K loop: Main loop: First: Loop + ".inst 0x25316160 // psel p0.s, p8.s/Z, p11.s[w13]\n" + ".inst 0x25316142 // psel p2.s, 
p8.s/Z, p10.s[w13]\n" + ".inst 0x25716161 // psel p1.s, p8.s/Z, p11.s[w13, #1]\n" + ".inst 0x25716143 // psel p3.s, p8.s/Z, p10.s[w13, #1]\n" + ".inst 0xe0972388 // ld1w { za2h.s[x13] }, p0/Z, [x28, x23, LSL #2]\n" + ".inst 0x25317120 // psel p0.s, p12.s/Z, p9.s[w13]\n" + "ldr x28, [x11, #0x0]\n" + ".inst 0xe0972b4c // ld1w { za3h.s[x13] }, p2/Z, [x26, x23, LSL #2]\n" + ".inst 0x25317122 // psel p2.s, p12.s/Z, p9.s[w13]\n" + "ldr x26, [x10, #0x0]\n" + ".inst 0xe0972709 // ld1w { za2h.s[x13, #1] }, p1/Z, [x24, x23, LSL #2]\n" + ".inst 0x25717121 // psel p1.s, p12.s/Z, p9.s[w13, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0972ead // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x23, LSL #2]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bfa2c0 // st1w { za0v.s[x13] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25717120 // psel p0.s, p12.s/Z, p9.s[w13, #1]\n" + ".inst 0xe0b1aac4 // st1w { za1v.s[x13] }, p2/Z, [x22, x17, LSL #2]\n" + "add x10, x10, #0x10\n" + ".inst 0xe0a9a6c1 // st1w { za0v.s[x13, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bba2c5 // st1w { za1v.s[x13, #1] }, p0/Z, [x22, x27, LSL #2]\n" + "add x13, x13, #0x2\n" + "addvl x22, x22, #4\n" + "cmp x13, x15\n" + "blt 5b\n" + "6:" // K loop: Main loop: First: Tail + ".inst 0x25316160 // psel p0.s, p8.s/Z, p11.s[w13]\n" + ".inst 0x25316142 // psel p2.s, p8.s/Z, p10.s[w13]\n" + ".inst 0x25716161 // psel p1.s, p8.s/Z, p11.s[w13, #1]\n" + ".inst 0x25716143 // psel p3.s, p8.s/Z, p10.s[w13, #1]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x17, LSL #3\n" + ".inst 0xe0972388 // ld1w { za2h.s[x13] }, p0/Z, [x28, x23, LSL #2]\n" + ".inst 0x25317120 // psel p0.s, p12.s/Z, p9.s[w13]\n" + "ldr x28, [x11, #0x0]\n" + "mov x12, #0x0\n" + ".inst 0xe0972b4c // ld1w { za3h.s[x13] }, p2/Z, [x26, x23, LSL #2]\n" + ".inst 0x25317122 // psel p2.s, p12.s/Z, p9.s[w13]\n" + "ldr x26, [x10, #0x0]\n" + ".inst 0xe0972709 // ld1w { za2h.s[x13, #1] }, p1/Z, [x24, x23, LSL #2]\n" + ".inst 0x25717121 // psel p1.s, p12.s/Z, 
p9.s[w13, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0972ead // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x23, LSL #2]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bfa2c0 // st1w { za0v.s[x13] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25717120 // psel p0.s, p12.s/Z, p9.s[w13, #1]\n" + ".inst 0xe0b1aac4 // st1w { za1v.s[x13] }, p2/Z, [x22, x17, LSL #2]\n" + "whilelt p9.s, x14, %x[width]\n" + "incw x14\n" + ".inst 0xe0a9a6c1 // st1w { za0v.s[x13, #1] }, p1/Z, [x22, x9, LSL #2]\n" + "add x10, x10, #0x10\n" + "incw x23\n" + ".inst 0xe0bba2c5 // st1w { za1v.s[x13, #1] }, p0/Z, [x22, x27, LSL #2]\n" + "addvl x22, x22, #4\n" + "whilelt p8.s, x14, %x[width]\n" + "cbz x15, 8f\n" + "7:" // K loop: Main loop: Second: Loop + ".inst 0x25306160 // psel p0.s, p8.s/Z, p11.s[w12]\n" + ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" + ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" + ".inst 0x25706143 // psel p3.s, p8.s/Z, p10.s[w12, #1]\n" + ".inst 0xe0970380 // ld1w { za0h.s[x12] }, p0/Z, [x28, x23, LSL #2]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + "ldr x28, [x11, #0x0]\n" + ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" + ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" + "ldr x26, [x10, #0x0]\n" + ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" + ".inst 0x25707121 // psel p1.s, p12.s/Z, p9.s[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0970ea5 // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x23, LSL #2]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25707120 // psel p0.s, p12.s/Z, p9.s[w12, #1]\n" + ".inst 0xe0b18acc // st1w { za3v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" + "add x10, x10, #0x10\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bb82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x27, LSL #2]\n" + "add x12, x12, #0x2\n" + 
"addvl x22, x22, #4\n" + "cmp x12, x15\n" + "blt 7b\n" + "8:" // K loop: Main loop: Second: Tail + ".inst 0x25306160 // psel p0.s, p8.s/Z, p11.s[w12]\n" + ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" + ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" + ".inst 0x25706143 // psel p3.s, p8.s/Z, p10.s[w12, #1]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x17, LSL #3\n" + ".inst 0xe0970380 // ld1w { za0h.s[x12] }, p0/Z, [x28, x23, LSL #2]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + "ldr x28, [x11, #0x0]\n" + ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" + ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" + "ldr x26, [x10, #0x0]\n" + ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" + ".inst 0x25707121 // psel p1.s, p12.s/Z, p9.s[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0970ea5 // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x23, LSL #2]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25707120 // psel p0.s, p12.s/Z, p9.s[w12, #1]\n" + ".inst 0xe0b18acc // st1w { za3v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" + "whilelt p9.s, x14, %x[width]\n" + "subs x20, x20, #0x1\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" + "add x10, x10, #0x10\n" + "incw x14\n" + ".inst 0xe0bb82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x27, LSL #2]\n" + "addvl x22, x22, #4\n" + "incw x23\n" + "bgt 4b\n" + "9:" // K loop: Tails + "cbnz x25, 12f\n" + "mov x11, %x[in]\n" + "whilelt p8.s, x14, %x[width]\n" + "mov x12, #0x0\n" + "10:" // K loop: Tails: Even: First + ".inst 0x25307123 // psel p3.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25306161 // psel p1.s, p8.s/Z, p11.s[w12]\n" + ".inst 0x25306140 // psel p0.s, p8.s/Z, p10.s[w12]\n" + ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0b18ac4 // st1w { 
za1v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" + "addvl x22, x22, #2\n" + "ldr x21, [x11, #0x0]\n" + "ldr x20, [x11, x17, LSL #0x3]\n" + "add x11, x11, #0x8\n" + ".inst 0xe09706a8 // ld1w { za2h.s[x12] }, p1/Z, [x21, x23, LSL #2]\n" + ".inst 0xe097028c // ld1w { za3h.s[x12] }, p0/Z, [x20, x23, LSL #2]\n" + "add x12, x12, #0x1\n" + "cmp x12, x17\n" + "blt 10b\n" + "whilelt p9.s, x14, %x[width]\n" + "whilelt p8.s, x14, %x[width]\n" + "mov x12, #0x0\n" + "11:" // K loop: Tails: Even: Second + ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0b182cc // st1w { za3v.s[x12] }, p0/Z, [x22, x17, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x16\n" + "blt 11b\n" + "whilelt p8.s, x14, %x[width]\n" + "b 14f\n" + "12:" // K loop: Tails: Odd + "mov x12, #0x0\n" + "13:" // K loop: Tails: Odd: Loop + ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0b182c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x17, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x16\n" + "blt 13b\n" + "14:" // K loop: End + "mov %x[out], x22\n" + ".inst 0xd503467f // SMSTOP\n" + : [out] "+&r"(out) + : [bias] "r"(bias), [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", + "p14", "p15", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", + "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", + "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", + "z26", "z27", "z28", "z29", "z30", "z31"); + + bias += height 
* kai_num_bytes_bias; + } +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h new file mode 100644 index 0000000000000000000000000000000000000000..8bd3f6ee783104b625724cf524b9e471145ecbe2 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h @@ -0,0 +1,81 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets n step value. +/// +/// The starting row index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(void); + +/// Gets the offset in bytes to the data element in the RHS matrix buffer. +/// +/// @param[in] n_idx Column index. +/// @param[in] rhs_offset Row stride in bytes of the RHS matrix. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx, size_t rhs_offset); + +/// Gets the offset in bytes to the data element in the bias buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the packed RHS buffer. +/// +/// @param[in] n_idx Row index. +/// @param[in] k Number of columns. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx, size_t k); + +/// Gets the size in bytes of the packed RHS buffer. +/// +/// @param[in] n Number of rows. +/// @param[in] k Number of columns. +/// +/// @return The size in bytes of the packed RHS buffer. 
+size_t kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t n, size_t k); + +/// Runs the RHS packing function for matrix multiplication. +/// +/// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset +/// calculated using the following functions: +/// +/// * RHS: @ref kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme. +/// * Bias: @ref kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme. +/// * Output: @ref kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme. +/// +/// @param[in] num_groups Number of groups. It must be 1. +/// @param[in] n Number of columns of the output matrix. +/// @param[in] k Common dimension between the LHS and RHS matrix. +/// @param[in] nr Block size in N dimension. It must be 2 * kai_get_sme_vector_length_u32(). +/// @param[in] kr Block size in K dimension. It must be 1. +/// @param[in] sr Number of kr splits. It must be 1. +/// @param[in] rhs_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs RHS matrix data buffer. +/// @param[in] bias Bias matrix data buffer. +/// @param[in] scale Scale data buffer. It must be NULL. +/// @param[out] rhs_packed Packed RHS matrix. +/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. +/// @param[in] params Extra packing parameters. It must be NULL. 
+void kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/test/common/data_format.cpp b/test/common/data_format.cpp index b4da18ae95cc27fad3d32f2c445daaa5eace11f3..20c47c98712240b41c4ac7ee40c56d7fb13a3f8f 100644 --- a/test/common/data_format.cpp +++ b/test/common/data_format.cpp @@ -154,7 +154,8 @@ uintptr_t DataFormat::default_offset_in_bytes(size_t row, size_t col, size_t wid switch (_pack_format) { case PackFormat::NONE: - return row * row_stride + col * data_type_size_in_bits(_data_type) / 8; + return row * row_stride / (_block_height > 0 ? _block_height : 1) + + col * data_type_size_in_bits(_data_type) / 8; case PackFormat::BIAS_PER_ROW: case PackFormat::QUANTIZE_PER_ROW: diff --git a/test/common/matmul_test_common.hpp b/test/common/matmul_test_common.hpp index 21f6e244c434d6789321c69ddf5b4a2471eec3d2..227dff11de3812fa4abba98325a1c7ff2188867a 100644 --- a/test/common/matmul_test_common.hpp +++ b/test/common/matmul_test_common.hpp @@ -34,9 +34,6 @@ struct MatMulMethod { size_t n0{0}; ///< Block size in N dimension. size_t k0{0}; ///< Block size in K dimension. - bool lhs_transposed; ///< LHS matrix is transposed. - bool rhs_transposed; ///< RHS matrix is transposed. - DataFormat dst_format; ///< Data format of the destination matrix. DataFormat lhs_format; ///< Data format of the LHS matrix. DataFormat packed_lhs_format; ///< Data format of the packed LHS matrix. @@ -192,6 +189,71 @@ struct MatMulMethod { const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)> fn_pack_rhs; + /// Gets n step value. + /// + /// The starting row index must be divisible by `n_step`. + /// + /// @return The n step value. 
+ std::function<size_t(void)> fn_pack_rhs_nxk_get_n_step{nullptr}; + + /// Gets the offset in bytes to the data element in the RHS matrix buffer. + /// + /// @param[in] n_idx Column index. + /// @param[in] rhs_offset Row stride in bytes of the RHS matrix. + /// + /// @return The offset in bytes to the data element. + std::function<size_t(size_t n_idx, size_t rhs_offset)> fn_pack_rhs_nxk_get_rhs_offset{nullptr}; + + /// Gets the offset in bytes to the data element in the bias buffer. + /// + /// @param[in] n_idx Column index. + /// + /// @return The offset in bytes to the data element. + std::function<size_t(size_t n_idx)> fn_pack_rhs_nxk_get_bias_offset{nullptr}; + + /// Gets the offset in bytes to the data element in the packed RHS buffer. + /// + /// @param[in] n_idx Row index. + /// @param[in] k Number of columns. + /// + /// @return The offset in bytes to the data element. + std::function<size_t(size_t n_idx, size_t k)> fn_pack_rhs_nxk_get_packed_rhs_offset{nullptr}; + + /// Gets the size in bytes of the packed RHS buffer. + /// + /// @param[in] n Number of rows. + /// @param[in] k Number of columns. + /// + /// @return The size in bytes of the packed RHS buffer. + std::function<size_t(size_t n, size_t k)> fn_pack_rhs_nxk_get_packed_rhs_size{nullptr}; + + /// Runs the RHS packing function for matrix multiplication. + /// + /// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset + /// calculated using the following functions: + /// + /// * RHS: @ref kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1b_f32_f32_sme. + /// * Bias: @ref kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1b_f32_f32_sme. + /// * Output: @ref kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1b_f32_f32_sme. + /// + /// @param[in] num_groups Number of groups. It must be 1. + /// @param[in] n Number of columns of the output matrix. + /// @param[in] k Common dimension between the LHS and RHS matrix. + /// @param[in] nr Block size in N dimension. It must be 2 * kai_get_sme_vector_length_u32(). + /// @param[in] kr Block size in K dimension. It must be 1. + /// @param[in] sr Number of kr splits. It must be 1. 
+ /// @param[in] rhs_stride Row stride in bytes of the RHS matrix. + /// @param[in] rhs RHS matrix data buffer. + /// @param[in] bias Bias matrix data buffer. + /// @param[in] scale Scale data buffer. It must be NULL. + /// @param[out] rhs_packed Packed RHS matrix. + /// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. + /// @param[in] params Extra packing parameters. It must be NULL. + std::function<void( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)> + fn_pack_rhs_nxk{nullptr}; + /// Gets the offset in bytes to the data element in the bias buffer. /// /// @param[in] n_idx Column index. @@ -290,6 +352,11 @@ struct MatMulMethod { return fn_pack_rhs != nullptr; } + /// Gets a value indicating whether pre-processing the transposed RHS matrix is needed. + [[nodiscard]] bool is_pack_rhs_nxk_needed() const { + return fn_pack_rhs_nxk != nullptr; + } + /// Preprocesses the RHS matrix. /// /// @param[in] n Size of the matrix in N dimension. /// @param[in] k Size of the matrix in K dimension. /// @param[in] rhs RHS data buffer. /// @param[in] rhs_row_stride RHS row stride. /// @param[in] bias Bias data buffer. /// @param[in] scale Quantization scales data buffer. /// @param[out] packed_rhs Packed RHS data buffer. @@ -318,6 +385,35 @@ struct MatMulMethod { } + /// Preprocesses the transposed RHS matrix. + /// + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] rhs RHS data buffer. + /// @param[in] rhs_row_stride RHS row stride. + /// @param[in] bias Bias data buffer. + /// @param[in] scale Quantization scales data buffer. + /// @param[out] packed_rhs Packed RHS data buffer. 
+ void pack_rhs_nxk( + size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale, + void* packed_rhs) const { + KAI_UNUSED(n); + KAI_UNUSED(k); + KAI_UNUSED(rhs); + KAI_UNUSED(rhs_row_stride); + KAI_UNUSED(bias); + KAI_UNUSED(scale); + KAI_UNUSED(packed_rhs); + + if (fn_pack_rhs_nxk != nullptr) { + fn_pack_rhs_nxk( + 1, n, k, fn_get_nr(), fn_get_kr(), fn_get_sr(), rhs_row_stride, rhs, bias, nullptr, packed_rhs, 0, + nullptr); + } else { + KAI_ERROR("RHS pre-processing is not supported!"); + } + } + [[nodiscard]] bool has_main_kernel() const { return fn_matmul_f16_f16_f16p != nullptr || fn_matmul_f32_f32p_f32p != nullptr || fn_matmul_f32_f32_f32p != nullptr || fn_matmul_f32_bf16p_bf16p != nullptr; diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp index 730ff5aed286f79ab87bac3014fa9b150000f4c6..ae9cb2ce77dd54bead021b76e499c5b8707c1036 100644 --- a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -47,9 +47,6 @@ const std::array matmul_methods = { .n0 = 12, .k0 = 4, - .lhs_transposed = false, - .rhs_transposed = false, - .dst_format = DataFormat(DataType::FP32), .lhs_format = DataFormat(DataType::FP32), .packed_lhs_format = @@ -95,9 +92,6 @@ const std::array matmul_methods = { .n0 = 12, .k0 = 4, - .lhs_transposed = false, - .rhs_transposed = false, - .dst_format = DataFormat(DataType::FP32), .lhs_format = DataFormat(DataType::FP32), .packed_lhs_format = @@ -173,8 +167,8 @@ protected: const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN; const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN; - const auto lhs_h = method.lhs_transposed ? info.k : info.m; - const auto lhs_w = method.lhs_transposed ? 
info.m : info.k; + const auto lhs_h = info.m; + const auto lhs_w = info.k; auto lhs = fill_matrix_random(lhs_h, lhs_w, method.lhs_format, 0); std::vector ref_packed_lhs; @@ -183,8 +177,8 @@ protected: pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w); } - const auto rhs_h = method.rhs_transposed ? info.n : info.k; - const auto rhs_w = method.rhs_transposed ? info.k : info.n; + const auto rhs_h = info.k; + const auto rhs_w = info.n; auto rhs = fill_matrix_random(rhs_h, rhs_w, method.rhs_format, 1); std::vector rhs_scales; @@ -223,7 +217,7 @@ protected: rhs.data(), rhs_scales.data(), nullptr, method.rhs_format.data_type(), // has_bias ? bias.data() : nullptr, nullptr, nullptr, method.bias_format.data_type(), // method.dst_format.data_type(), // - info.m, info.n, info.k, method.lhs_transposed, method.rhs_transposed); + info.m, info.n, info.k, false, false); const auto& data = _data[data_id] = { .lhs = std::move(lhs), @@ -279,7 +273,7 @@ TEST_P(MatMulTestBf16, Output) { const size_t dst_w = info.n; const bool has_bias = (data.bias.size() > 0); - const auto lhs_start_row = method.lhs_transposed ? 
0 : rect.start_row(); + const auto lhs_start_row = rect.start_row(); const auto lhs_stride = method.lhs_format.default_row_stride(lhs_w); std::vector lhs_data; diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index 752a370147de64238ee2a4705257578f1d584323..95b55200049de423ccf1d1fc3aaa43ec7bb302cf 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -40,10 +40,13 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" #include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h" // matmul_clamp_f32_f32_f32p #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h" +#include "test/reference/transpose.hpp" + namespace kai::test { /// List of supported matrix multiplication methods. 
@@ -54,9 +57,6 @@ static const std::array matmul_methods = { .m0 = 6, .n0 = 16, - .lhs_transposed = false, - .rhs_transposed = false, - .dst_format = DataFormat(DataType::FP16), .lhs_format = DataFormat(DataType::FP16), .packed_lhs_format = DataFormat(DataType::UNKNOWN), @@ -86,6 +86,13 @@ static const std::array matmul_methods = { .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, .fn_pack_rhs = kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, + .fn_pack_rhs_nxk_get_n_step = nullptr, + .fn_pack_rhs_nxk_get_rhs_offset = nullptr, + .fn_pack_rhs_nxk_get_bias_offset = nullptr, + .fn_pack_rhs_nxk_get_packed_rhs_offset = nullptr, + .fn_pack_rhs_nxk_get_packed_rhs_size = nullptr, + .fn_pack_rhs_nxk = nullptr, + .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, @@ -102,9 +109,6 @@ static const std::array matmul_methods = { .m0 = 6, .n0 = 8, - .lhs_transposed = false, - .rhs_transposed = false, - .dst_format = DataFormat(DataType::FP32), .lhs_format = DataFormat(DataType::FP32), .packed_lhs_format = DataFormat(DataType::UNKNOWN), @@ -134,6 +138,13 @@ static const std::array matmul_methods = { .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, .fn_pack_rhs = kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, + .fn_pack_rhs_nxk_get_n_step = nullptr, + .fn_pack_rhs_nxk_get_rhs_offset = nullptr, + .fn_pack_rhs_nxk_get_bias_offset = nullptr, + .fn_pack_rhs_nxk_get_packed_rhs_offset = nullptr, + .fn_pack_rhs_nxk_get_packed_rhs_size = nullptr, + .fn_pack_rhs_nxk = nullptr, + .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, @@ -150,9 +161,6 @@ static const std::array matmul_methods = { .m0 = 2 * 
get_sme_vector_length(), .n0 = 2 * get_sme_vector_length(), - .lhs_transposed = false, - .rhs_transposed = false, - .dst_format = DataFormat(DataType::FP32), .lhs_format = DataFormat(DataType::FP32), .packed_lhs_format = DataFormat(DataType::FP32, 2 * get_sme_vector_length(), 1), @@ -185,6 +193,13 @@ static const std::array matmul_methods = { kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, .fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, + .fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme, + .fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme, + .fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme, + .fn_pack_rhs_nxk_get_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme, + .fn_pack_rhs_nxk_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme, + .fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme, + .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, @@ -210,6 +225,7 @@ protected: std::vector rhs{}; ///< RHS operand. std::vector rhs_scales{}; ///< RHS per-row quantization scales. std::vector bias{}; ///< Bias. + std::vector rhs_t{}; ///< Transposed RHS matrix. std::vector ref_packed_rhs{}; ///< Reference packed RHS. std::vector ref_dst{}; ///< Reference output. }; @@ -231,8 +247,8 @@ protected: const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN; const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN; - const auto lhs_h = method.lhs_transposed ? info.k : info.m; - const auto lhs_w = method.lhs_transposed ? 
info.m : info.k; + const auto lhs_h = info.m; + const auto lhs_w = info.k; auto lhs = fill_matrix_random(lhs_h, lhs_w, method.lhs_format, 0); std::vector ref_packed_lhs; @@ -241,10 +257,13 @@ protected: pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w); } - const auto rhs_h = method.rhs_transposed ? info.n : info.k; - const auto rhs_w = method.rhs_transposed ? info.k : info.n; + const auto rhs_h = info.k; + const auto rhs_w = info.n; auto rhs = fill_matrix_random(rhs_h, rhs_w, method.rhs_format, 1); + KAI_ASSUME(method.rhs_format.is_raw()); + auto rhs_t = transpose(rhs.data(), method.rhs_format.data_type(), rhs_h, rhs_w); + std::vector rhs_scales; if (data_type_is_quantized(method.rhs_format.data_type()) && method.rhs_format.pack_format() == DataFormat::PackFormat::NONE) { @@ -263,7 +282,7 @@ protected: if (has_rhs_pack) { packed_rhs = matmul_pack_rhs( rhs.data(), !rhs_scales.empty() ? rhs_scales.data() : nullptr, bias.data(), method.rhs_format, - method.packed_rhs_format, info.n, info.k, !method.rhs_transposed); + method.packed_rhs_format, info.n, info.k, true); } KAI_ASSUME(method.lhs_format.is_raw()); @@ -274,7 +293,7 @@ protected: rhs.data(), rhs_scales.data(), nullptr, method.rhs_format.data_type(), // bias.data(), nullptr, nullptr, method.bias_format.data_type(), // method.dst_format.data_type(), // - info.m, info.n, info.k, method.lhs_transposed, method.rhs_transposed); + info.m, info.n, info.k, false, false); const auto& data = _data[data_id] = { .lhs = std::move(lhs), @@ -282,6 +301,7 @@ protected: .rhs = std::move(rhs), .rhs_scales = std::move(rhs_scales), .bias = std::move(bias), + .rhs_t = std::move(rhs_t), .ref_packed_rhs = std::move(packed_rhs), .ref_dst = std::move(ref_dst), }; @@ -312,8 +332,8 @@ TEST_P(MatMulTest, PackedLhs) { GTEST_SKIP(); } - const auto lhs_h = method.lhs_transposed ? info.k : info.m; - const auto lhs_w = method.lhs_transposed ? 
info.m : info.k; + const auto lhs_h = info.m; + const auto lhs_w = info.k; const auto rect = portion.compute_portion( lhs_h, lhs_w, method.packed_lhs_format.scheduler_block_height(lhs_h), @@ -362,7 +382,7 @@ TEST_P(MatMulTest, PackedRhs) { GTEST_SKIP(); } - const auto rhs_w = method.rhs_transposed ? info.k : info.n; + const auto rhs_w = info.n; const auto packed_rhs_h = info.n; const auto packed_rhs_w = info.k; @@ -378,8 +398,8 @@ TEST_P(MatMulTest, PackedRhs) { GTEST_SKIP(); } - const auto rhs_start_row = method.rhs_transposed ? rect.start_row() : rect.start_col(); - const auto rhs_start_col = method.rhs_transposed ? rect.start_col() : rect.start_row(); + const auto rhs_start_row = rect.start_col(); + const auto rhs_start_col = rect.start_row(); const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(rhs_w); @@ -419,6 +439,68 @@ TEST_P(MatMulTest, PackedRhs) { ASSERT_TRUE(success); } +/// Tests the transposed RHS packing kernel. +TEST_P(MatMulTest, PackedTransposedRhs) { + const auto& [method, info, portion] = GetParam(); + const auto& data = test_data(); + + if (method.fn_is_supported && !method.fn_is_supported()) { + GTEST_SKIP(); + } + + if (!method.is_pack_rhs_nxk_needed()) { + GTEST_SKIP(); + } + + const auto n_step = method.fn_pack_rhs_nxk_get_n_step(); + const auto ref_n_step = method.packed_rhs_format.scheduler_block_height(info.n); + ASSERT_EQ(n_step, ref_n_step); + + const auto rect = portion.compute_portion( + info.n, info.k, method.packed_rhs_format.scheduler_block_height(info.n), + method.packed_rhs_format.scheduler_block_width(info.k)); + + if (rect.height() == 0 || rect.width() == 0) { + GTEST_SKIP(); + } + + const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(info.k); + + const auto rhs_offset = method.fn_pack_rhs_nxk_get_rhs_offset(rect.start_row(), ref_rhs_row_stride); + const auto ref_rhs_offset = method.rhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), info.k); + ASSERT_EQ(rhs_offset, 
ref_rhs_offset); + + const auto packed_rhs_size = method.fn_pack_rhs_nxk_get_packed_rhs_size(info.n, info.k); + const auto ref_packed_rhs_size = method.packed_rhs_format.default_size_in_bytes(info.n, info.k); + ASSERT_EQ(packed_rhs_size, ref_packed_rhs_size); + + const auto packed_rhs_offset = method.fn_pack_rhs_nxk_get_packed_rhs_offset(rect.start_row(), info.k); + const auto ref_packed_rhs_offset = + method.packed_rhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), info.k); + ASSERT_EQ(packed_rhs_offset, ref_packed_rhs_offset); + + const auto ref_rhs_scales_offset = + rect.start_row() * data_type_size_in_bits(method.packed_rhs_format.scale_data_type()) / 8; + + const auto bias_offset = method.fn_get_bias_offset(rect.start_row()); + const auto ref_bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_row(), info.n); + ASSERT_EQ(bias_offset, ref_bias_offset); + + std::vector packed_rhs; + packed_rhs.resize(packed_rhs_size); + + method.pack_rhs_nxk( + rect.height(), rect.width(), data.rhs_t.data() + rhs_offset, ref_rhs_row_stride, data.bias.data() + bias_offset, + !data.rhs_scales.empty() ? data.rhs_scales.data() + ref_rhs_scales_offset : nullptr, + packed_rhs.data() + packed_rhs_offset); + + const auto exact = method.packed_rhs_format.pack_format() != DataFormat::PackFormat::QUANTIZE_PER_ROW; + DefaultMismatchHandler handler(0, exact ? 0 : 0.0001, 0, exact ? 0 : 0.001); + const auto success = + compare(packed_rhs.data(), data.ref_packed_rhs.data(), method.packed_rhs_format, info.n, info.k, rect, handler); + ASSERT_TRUE(success); +} + /// Tests the output. TEST_P(MatMulTest, Output) { const auto& [method, info, portion] = GetParam(); @@ -444,13 +526,13 @@ TEST_P(MatMulTest, Output) { GTEST_SKIP(); } - const auto lhs_w = method.lhs_transposed ? info.m : info.k; - const auto rhs_w = method.rhs_transposed ? 
info.k : info.n; + const auto lhs_w = info.k; + const auto rhs_w = info.n; const auto bias_w = info.n; const auto dst_w = info.n; - const auto lhs_start_row = method.lhs_transposed ? 0 : rect.start_row(); - const auto lhs_start_col = method.lhs_transposed ? rect.start_row() : 0; + const auto lhs_start_row = rect.start_row(); + const auto lhs_start_col = 0; const auto lhs_stride = method.lhs_format.default_row_stride(lhs_w); const uint8_t* lhs_data = nullptr; @@ -458,7 +540,11 @@ TEST_P(MatMulTest, Output) { if (method.is_pack_lhs_needed()) { lhs_data = data.ref_packed_lhs.data(); - lhs_offset = method.packed_lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, info.k); + + const auto ref_packed_lhs_offset = + method.packed_lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, info.k); + lhs_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k); + ASSERT_EQ(lhs_offset, ref_packed_lhs_offset); } else { lhs_data = data.lhs.data(); @@ -483,8 +569,8 @@ TEST_P(MatMulTest, Output) { method.packed_rhs_format.default_offset_in_bytes(packed_rhs_start_row, packed_rhs_start_col, info.k); ASSERT_EQ(rhs_offset, ref_rhs_offset); } else { - const auto rhs_start_row = method.rhs_transposed ? rect.start_col() : 0; - const auto rhs_start_col = method.rhs_transposed ? 0 : rect.start_col(); + const auto rhs_start_row = 0; + const auto rhs_start_col = rect.start_col(); rhs_data = data.rhs.data(); rhs_offset = method.rhs_format.default_offset_in_bytes(rhs_start_row, rhs_start_col, rhs_w); @@ -524,7 +610,8 @@ INSTANTIATE_TEST_SUITE_P( MatMulShape{20, 1, 20}, // MatMulShape{6, 16, 32}, // MatMulShape{12, 32, 17}, // - MatMulShape{13, 33, 23} // + MatMulShape{13, 33, 23}, // + MatMulShape{87, 93, 56} // ), testing::Values( MatrixPortion(0, 0, 1, 1), // Full matrix.