diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 24a629a513f7fd578e6d09c5d6e51f58182f2634..b7658016c49b8dde263d9556f5c1b9f45dc4328a 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -86,7 +86,7 @@ test-linux-aarch64: dependencies: - build-clang script: - - ./build/kleidiai_test + - ./build/kleidiai_test --gtest_filter=*:-*sme* test-linux-aarch64-cov: extends: @@ -96,7 +96,7 @@ test-linux-aarch64-cov: dependencies: - build-clang-cov script: - - ./build/kleidiai_test + - ./build/kleidiai_test --gtest_filter=*:-*sme* - mkdir -p build/coverage - gcovr --gcov-executable="llvm-cov gcov" --exclude-unreachable-branches --exclude=build --exclude=test --exclude-lines-by-pattern=".*KAI_(?:ASSERT|ASSUME|ERROR).*" --exclude-branches-by-pattern=".*KAI_(?:ASSERT|ASSUME).*" --json=build/coverage/linux-aarch64.json -j --root . build artifacts: @@ -122,7 +122,7 @@ test-linux-aarch64-cov-fvp: cd '$PWD' mkdir -p artifacts - GCOV_PREFIX=artifacts ./build/kleidiai_test && echo 'FINISHED WITHOUT ERROR' + GCOV_PREFIX=artifacts ./build/kleidiai_test --gtest_filter=*sme* && echo 'FINISHED WITHOUT ERROR' tar cf artifacts.tar -C artifacts . sync @@ -166,6 +166,7 @@ test-linux-aarch64-cov-fvp: -C cluster0.sve.has_b16b16=1 \ -C cluster0.sve.has_sve2=1 \ -C cluster0.sve.has_sme=1 \ + -C cluster0.sve.has_sme2=1 \ -C cluster0.sve.has_sme_f16f16=1 \ -C cluster0.sve.has_sme_fa64=1 \ -C cluster0.sve.has_sme_lutv2=1 \ diff --git a/CMakeLists.txt b/CMakeLists.txt index 1572985c5d598770d89b75c17891111fab04982e..1f8c772938990f89eec476e52bd951ec2c670af6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,11 +79,22 @@ set(KLEIDIAI_FILES_NEON_I8MM kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c ) +set(KLEIDIAI_FILES_SME + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c + kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c +) + +set(KLEIDIAI_FILES_SME2 + kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c +) + add_library(kleidiai ${KLEIDIAI_FILES_NEON} ${KLEIDIAI_FILES_NEON_FP16} ${KLEIDIAI_FILES_NEON_DOTPROD} ${KLEIDIAI_FILES_NEON_I8MM} + ${KLEIDIAI_FILES_SME} + ${KLEIDIAI_FILES_SME2} ) target_include_directories(kleidiai @@ -110,11 +121,19 @@ foreach(KLEIDIAI_SOURCE_FILE IN LISTS KLEIDIAI_FILES_NEON_I8MM) set_property(SOURCE ${KLEIDIAI_SOURCE_FILE} PROPERTY COMPILE_OPTIONS -march=armv8.2-a+i8mm) endforeach() +foreach(KLEIDIAI_SOURCE_FILE IN LISTS KLEIDIAI_FILES_SME) + set_property(SOURCE ${KLEIDIAI_SOURCE_FILE} PROPERTY COMPILE_OPTIONS -march=armv8.2-a+sve+sve2) +endforeach() + +foreach(KLEIDIAI_SOURCE_FILE IN LISTS KLEIDIAI_FILES_SME2) + set_property(SOURCE ${KLEIDIAI_SOURCE_FILE} PROPERTY COMPILE_OPTIONS -march=armv8.2-a+sve+sve2) +endforeach() + if(KLEIDIAI_BUILD_TESTS) enable_testing() include(GoogleTest) - add_executable(kleidiai_test + add_library(kleidiai_test_framework test/common/data_type.cpp test/common/data_format.cpp test/common/printer.cpp @@ -124,6 +143,8 @@ if(KLEIDIAI_BUILD_TESTS) test/common/rect.cpp test/common/bfloat16.cpp test/common/float16.cpp + test/common/cpu_info.cpp + test/common/sme.cpp test/reference/binary_elementwise.cpp test/reference/matmul.cpp @@ -134,18 +155,26 @@ if(KLEIDIAI_BUILD_TESTS) test/reference/round.cpp test/reference/transpose.cpp test/reference/cast.cpp + ) - test/tests/matmul_test.cpp + target_compile_options(kleidiai_test_framework + PUBLIC ${KLEIDIAI_WARNING_FLAGS} + PUBLIC -march=armv8.2-a+fp16+bf16 + ) + + set_property(SOURCE test/common/sme.cpp PROPERTY COMPILE_OPTIONS -march=armv8.2-a+sve) + + target_link_libraries(kleidiai_test_framework + PUBLIC kleidiai ) - target_compile_options(kleidiai_test - PRIVATE ${KLEIDIAI_WARNING_FLAGS} - PRIVATE -march=armv8.2-a+fp16+bf16 + add_executable(kleidiai_test + test/tests/matmul_test.cpp ) target_link_libraries(kleidiai_test + PRIVATE kleidiai_test_framework PRIVATE GTest::gtest_main - PRIVATE kleidiai ) # Cross-compiling is a common use case which creates a conflict if DISCOVERY_MODE is set to POST_BUILD (by default) diff --git a/kai/kai_common.h b/kai/kai_common.h index 95dce8b5d277ada195d7964f1b2317fd75e6bc11..c40796faba389ee893e67e3a562f209e0bd7acde 100644 --- a/kai/kai_common.h +++ b/kai/kai_common.h @@ -52,6 +52,58 @@ inline static size_t kai_roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; } +#ifdef __ARM_FEATURE_SVE + +/// Gets the SME vector length for 8-bit elements. +inline static uint64_t kai_get_sme_vector_length_u8(void) { + uint64_t res = 0; + + __asm __volatile( + ".inst 0xd503477f // SMSTART ZA\n" + "cntb %0\n" + ".inst 0xd503467f // SMSTOP\n" + : "=r"(res) + : + : "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", + "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); + + return res; +} + +/// Gets the SME vector length for 16-bit elements. +inline static uint64_t kai_get_sme_vector_length_u16(void) { + uint64_t res = 0; + + __asm __volatile( + ".inst 0xd503477f // SMSTART ZA\n" + "cnth %0\n" + ".inst 0xd503467f // SMSTOP\n" + : "=r"(res) + : + : "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", + "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); + + return res; +} + +/// Gets the SME vector length for 32-bit elements. +inline static uint64_t kai_get_sme_vector_length_u32(void) { + uint64_t res = 0; + + __asm __volatile( + ".inst 0xd503477f // SMSTART ZA\n" + "cntw %0\n" + ".inst 0xd503467f // SMSTOP\n" + : "=r"(res) + : + : "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", + "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); + + return res; +} + +#endif // __ARM_FEATURE_SVE + #ifdef __cplusplus } #endif diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c new file mode 100644 index 0000000000000000000000000000000000000000..b436d46a4c5914d18e1652cd0ce75c5109c056b4 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c @@ -0,0 +1,484 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_mr = 2; +static const size_t kai_nr = 2; +static const size_t kai_kr = 1; +static const size_t kai_sr = 1; + +size_t kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) { + return kai_mr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) { + return kai_nr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) { + return kai_mr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) { + return kai_nr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t m_idx, size_t k) { + KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() == 0); + return m_idx * k * sizeof(float); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() == 0); + return n_idx * (k * sizeof(float) + sizeof(float)); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() == 0); + KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() == 0); + + return m_idx * dst_stride + n_idx * sizeof(float); +} + +size_t kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, + size_t dst_stride_col, float clamp_min, float clamp_max) { + KAI_ASSUME(dst_stride_col == sizeof(float)); + + typedef struct { + const void* A; + const void* B; + + void* C; + long ldcb; + long M, N, K; + float min; + float max; + + void* accumulator_buffer; + uint64_t flags; + } KernelArgs; + + KernelArgs args; + + args.A = lhs_packed; + args.B = rhs_packed; + + args.C = dst; + args.ldcb = dst_stride_row; + args.M = m; + args.N = n; + args.K = k; + args.min = clamp_min; + args.max = clamp_max; + + args.accumulator_buffer = NULL; + args.flags = 0; + + __asm__ __volatile__( + "ldr x17, [%x[args], %[offsetof_flags]]\n" + ".inst 0xd503477f // SMSTART ZA\n" + "ptrue p0.b\n" + ".inst 0x25207811 // ptrue pn9.b\n" + "ldr x16, [%x[args], %[offsetof_accumulator_buffer]]\n" + "ldr x15, [%x[args], %[offsetof_accumulator_buffer]]\n" + "tbz x17, #0, 2f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "1:" // Initial accumulator load from buffer: Loop + ".inst 0xa040c618 // ld1w { z24.s-z27.s }, pn9.b/Z, [x16]\n" + ".inst 0xa041c60c // ld1w { z12.s-z15.s }, pn9.b/Z, [x16, #0x4, MUL VL]\n" + ".inst 0xa042c600 // ld1w { z0.s-z3.s }, pn9.b/Z, [x16, #0x8, MUL VL]\n" + ".inst 0xa043c610 // ld1w { z16.s-z19.s }, pn9.b/Z, [x16, #0xc, MUL VL]\n" + ".inst 0xc0840700 // mova za0h.s[x12], { z24.s-z27.s }\n" + "addvl x16, x16, #16\n" + ".inst 0xc0840581 // mova za1h.s[x12], { z12.s-z15.s }\n" + ".inst 0xc0840402 // mova za2h.s[x12], { z0.s-z3.s }\n" + ".inst 0xc0840603 // mova za3h.s[x12], { z16.s-z19.s }\n" + "add x12, x12, #0x4\n" + "cmp x12, x20\n" + "blt 1b\n" + "2:" // Initial accumulator load from buffer: End + "ldr w14, [%x[args], %[offsetof_M]]\n" + "mov x13, #0x0\n" + "mov x11, #0x0\n" + "ldr w10, [%x[args], %[offsetof_N]]\n" + "ldr x9, [%x[args], %[offsetof_A]]\n" + "3:" // M loop + "ldr x28, [%x[args], %[offsetof_B]]\n" + "4:" // N loop + "mov x27, x9\n" + ".inst 0x25aa4570 // whilelt pn8.s, x11, x10, VLx2\n" + "tbnz x17, #0, 5f\n" + "fmov z17.s, #1.0\n" + ".inst 0xa040438a // ld1w { z10.s-z11.s }, p8/Z, [x28]\n" // Load bias + ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" + "addvl x28, x28, #2\n" + ".inst 0x808a0220 // fmopa za0.s, p0/M, p0/M, z17.s, z10.s\n" + ".inst 0x808b0221 // fmopa za1.s, p0/M, p0/M, z17.s, z11.s\n" + ".inst 0x808a0222 // fmopa za2.s, p0/M, p0/M, z17.s, z10.s\n" + ".inst 0x808b0223 // fmopa za3.s, p0/M, p0/M, z17.s, z11.s\n" + "5:" // Prepare accumulators: Test for last block + "mov x20, x11\n" + "mov x21, x13\n" + "incw x20, ALL, MUL #2\n" + "incw x21, ALL, MUL #2\n" + "cmp x20, x10\n" + "mov x20, x17\n" + "csel x21, x13, x21, LT\n" + "bfm x17, XZR, #0x0, #0x0 // bfc x17, #0x0, #0x1\n" + "cmp x21, x14\n" + "csel x17, x20, x17, LT\n" + "ldr x20, [%x[args], %[offsetof_K]]\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 9f\n" + "subs x21, x21, #0x1\n" + ".inst 0xa0404776 // ld1w { z22.s-z23.s }, pn9.b/Z, [x27]\n" + ".inst 0xa1404787 // ld1w { z7.s, z15.s }, pn9.b/Z, [x28]\n" + ".inst 0xa1414766 // ld1w { z6.s, z14.s }, pn9.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0xa0414794 // ld1w { z20.s-z21.s }, pn9.b/Z, [x28, #0x2, MUL VL]\n" + ".inst 0xa1424762 // ld1w { z2.s, z10.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa1424783 // ld1w { z3.s, z11.s }, pn9.b/Z, [x28, #0x4, MUL VL]\n" + ".inst 0xa1434761 // ld1w { z1.s, z9.s }, pn9.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + ".inst 0xa0434784 // ld1w { z4.s-z5.s }, pn9.b/Z, [x28, #0x6, MUL VL]\n" + "addvl x28, x28, #8\n" + "ble 8f\n" + "7:" // K loop + ".inst 0x808702c0 // fmopa za0.s, p0/M, p0/M, z22.s, z7.s\n" + "subs x21, x21, #0x1\n" + ".inst 0x808f02c1 // fmopa za1.s, p0/M, p0/M, z22.s, z15.s\n" + ".inst 0x808702e2 // fmopa za2.s, p0/M, p0/M, z23.s, z7.s\n" + ".inst 0x808f02e3 // fmopa za3.s, p0/M, p0/M, z23.s, z15.s\n" + ".inst 0xa0404776 // ld1w { z22.s-z23.s }, pn9.b/Z, [x27]\n" + ".inst 0x809400c0 // fmopa za0.s, p0/M, p0/M, z6.s, z20.s\n" + ".inst 0xa1404787 // ld1w { z7.s, z15.s }, pn9.b/Z, [x28]\n" + ".inst 0x809500c1 // fmopa za1.s, p0/M, p0/M, z6.s, z21.s\n" + ".inst 0x809401c2 // fmopa za2.s, p0/M, p0/M, z14.s, z20.s\n" + ".inst 0x809501c3 // fmopa za3.s, p0/M, p0/M, z14.s, z21.s\n" + ".inst 0xa1414766 // ld1w { z6.s, z14.s }, pn9.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0x80830040 // fmopa za0.s, p0/M, p0/M, z2.s, z3.s\n" + ".inst 0xa0414794 // ld1w { z20.s-z21.s }, pn9.b/Z, [x28, #0x2, MUL VL]\n" + ".inst 0x808b0041 // fmopa za1.s, p0/M, p0/M, z2.s, z11.s\n" + ".inst 0x80830142 // fmopa za2.s, p0/M, p0/M, z10.s, z3.s\n" + ".inst 0x808b0143 // fmopa za3.s, p0/M, p0/M, z10.s, z11.s\n" + ".inst 0xa1424762 // ld1w { z2.s, z10.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa1424783 // ld1w { z3.s, z11.s }, pn9.b/Z, [x28, #0x4, MUL VL]\n" + ".inst 0x80840020 // fmopa za0.s, p0/M, p0/M, z1.s, z4.s\n" + ".inst 0x80850021 // fmopa za1.s, p0/M, p0/M, z1.s, z5.s\n" + ".inst 0x80840122 // fmopa za2.s, p0/M, p0/M, z9.s, z4.s\n" + ".inst 0x80850123 // fmopa za3.s, p0/M, p0/M, z9.s, z5.s\n" + ".inst 0xa1434761 // ld1w { z1.s, z9.s }, pn9.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + ".inst 0xa0434784 // ld1w { z4.s-z5.s }, pn9.b/Z, [x28, #0x6, MUL VL]\n" + "addvl x28, x28, #8\n" + "bgt 7b\n" + "8:" // K loop tail + ".inst 0x808702c0 // fmopa za0.s, p0/M, p0/M, z22.s, z7.s\n" + ".inst 0x808f02c1 // fmopa za1.s, p0/M, p0/M, z22.s, z15.s\n" + ".inst 0x808702e2 // fmopa za2.s, p0/M, p0/M, z23.s, z7.s\n" + ".inst 0x808f02e3 // fmopa za3.s, p0/M, p0/M, z23.s, z15.s\n" + ".inst 0x809400c0 // fmopa za0.s, p0/M, p0/M, z6.s, z20.s\n" + ".inst 0x809500c1 // fmopa za1.s, p0/M, p0/M, z6.s, z21.s\n" + ".inst 0x809401c2 // fmopa za2.s, p0/M, p0/M, z14.s, z20.s\n" + ".inst 0x809501c3 // fmopa za3.s, p0/M, p0/M, z14.s, z21.s\n" + ".inst 0x80830040 // fmopa za0.s, p0/M, p0/M, z2.s, z3.s\n" + ".inst 0x808b0041 // fmopa za1.s, p0/M, p0/M, z2.s, z11.s\n" + ".inst 0x80830142 // fmopa za2.s, p0/M, p0/M, z10.s, z3.s\n" + ".inst 0x808b0143 // fmopa za3.s, p0/M, p0/M, z10.s, z11.s\n" + ".inst 0x80840020 // fmopa za0.s, p0/M, p0/M, z1.s, z4.s\n" + ".inst 0x80850021 // fmopa za1.s, p0/M, p0/M, z1.s, z5.s\n" + ".inst 0x80840122 // fmopa za2.s, p0/M, p0/M, z9.s, z4.s\n" + ".inst 0x80850123 // fmopa za3.s, p0/M, p0/M, z9.s, z5.s\n" + "9:" // K oddments + "cbz x20, 11f\n" + "10:" // K oddments: Loop + ".inst 0xa040476a // ld1w { z10.s-z11.s }, pn9.b/Z, [x27]\n" + "subs x20, x20, #0x1\n" + "addvl x27, x27, #2\n" + ".inst 0xa040478e // ld1w { z14.s-z15.s }, pn9.b/Z, [x28]\n" + "addvl x28, x28, #2\n" + ".inst 0x808e0140 // fmopa za0.s, p0/M, p0/M, z10.s, z14.s\n" + ".inst 0x808f0141 // fmopa za1.s, p0/M, p0/M, z10.s, z15.s\n" + ".inst 0x808e0162 // fmopa za2.s, p0/M, p0/M, z11.s, z14.s\n" + ".inst 0x808f0163 // fmopa za3.s, p0/M, p0/M, z11.s, z15.s\n" + "bgt 10b\n" + "11:" // K oddments: End + "tbz x17, #1, 15f\n" + "tbz x17, #0, 13f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "12:" // Store to partial result buffer: Store and refill: Loop + ".inst 0xa040c600 // ld1w { z0.s-z3.s }, pn9.b/Z, [x16]\n" + ".inst 0xc0860414 // mova { z20.s-z23.s }, za0h.s[x12]\n" + ".inst 0xc086043c // mova { z28.s-z31.s }, za1h.s[x12]\n" + ".inst 0xa041c604 // ld1w { z4.s-z7.s }, pn9.b/Z, [x16, #0x4, MUL VL]\n" + ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n" + ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n" + ".inst 0xa042c610 // ld1w { z16.s-z19.s }, pn9.b/Z, [x16, #0x8, MUL VL]\n" + ".inst 0xa043c618 // ld1w { z24.s-z27.s }, pn9.b/Z, [x16, #0xc, MUL VL]\n" + ".inst 0xc0840400 // mova za0h.s[x12], { z0.s-z3.s }\n" + "addvl x16, x16, #16\n" + ".inst 0xc0840481 // mova za1h.s[x12], { z4.s-z7.s }\n" + ".inst 0xa060c5f4 // st1w { z20.s-z23.s }, pn9.b, [x15]\n" + ".inst 0xc0840602 // mova za2h.s[x12], { z16.s-z19.s }\n" + ".inst 0xa061c5fc // st1w { z28.s-z31.s }, pn9.b, [x15, #0x4, MUL VL]\n" + ".inst 0xc0840703 // mova za3h.s[x12], { z24.s-z27.s }\n" + "add x12, x12, #0x4\n" + ".inst 0xa062c5e8 // st1w { z8.s-z11.s }, pn9.b, [x15, #0x8, MUL VL]\n" + "cmp x12, x20\n" + ".inst 0xa063c5ec // st1w { z12.s-z15.s }, pn9.b, [x15, #0xc, MUL VL]\n" + "addvl x15, x15, #16\n" + "blt 12b\n" + "b 31f\n" + "13:" // Store to partial result buffer: Store only + "mov x12, #0x0\n" + "cntw x20\n" + "14:" // Store to partial result buffer: Store only: Loop + ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n" + ".inst 0xc0860430 // mova { z16.s-z19.s }, za1h.s[x12]\n" + ".inst 0xc086045c // mova { z28.s-z31.s }, za2h.s[x12]\n" + ".inst 0xc0860474 // mova { z20.s-z23.s }, za3h.s[x12]\n" + ".inst 0xa060c5e0 // st1w { z0.s-z3.s }, pn9.b, [x15]\n" + "add x12, x12, #0x4\n" + ".inst 0xa061c5f0 // st1w { z16.s-z19.s }, pn9.b, [x15, #0x4, MUL VL]\n" + "cmp x12, x20\n" + ".inst 0xa062c5fc // st1w { z28.s-z31.s }, pn9.b, [x15, #0x8, MUL VL]\n" + ".inst 0xa063c5f4 // st1w { z20.s-z23.s }, pn9.b, [x15, #0xc, MUL VL]\n" + "addvl x15, x15, #16\n" + "blt 14b\n" + "b 31f\n" + "15:" // Store to output array + "ldr x26, [%x[args], %[offsetof_C]]\n" + "sub x25, x14, x13\n" + "ldr x24, [%x[args], %[offsetof_ldcb]]\n" + "add x26, x26, x11, LSL #2\n" // C += n + "madd x26, x13, x24, x26\n" // C += m * ldc + "tbz x17, #2, 22f\n" + "cntw x23\n" + "mov x12, #0x0\n" + "cmp x25, x23\n" + "csel x22, x25, x23, LT\n" + "lsr x21, x22, #0x2\n" + "and x20, x22, #0x3\n" + "cbz x21, 17f\n" + "16:" // Store to output array: Skip activation: Accumulator row 0 loop + ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n" + ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n" + ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "add x12, x12, #0x4\n" + ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "cmp x12, x21, LSL #2\n" + ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604347 // st1w { z7.s, z15.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "blt 16b\n" + "17:" // Store to output array: Skip activation: Accumulator row 0 oddments + "cbz x20, 18f\n" + "subs x20, x20, #0x1\n" + ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n" + ".inst 0xc0860428 // mova { z8.s-z11.s }, za1h.s[x12]\n" + ".inst 0xa1604340 // st1w { z0.s, z8.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "beq 18f\n" + "subs x20, x20, #0x1\n" + ".inst 0xa1604341 // st1w { z1.s, z9.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "beq 18f\n" + ".inst 0xa1604342 // st1w { z2.s, z10.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "18:" // Store to output array: Skip activation: Accumulator row 0 oddments: End + "subs x25, x25, x22\n" + "beq 22f\n" + "cmp x25, x23\n" + "mov x12, #0x0\n" + "csel x22, x25, x23, LT\n" + "lsr x21, x22, #0x2\n" + "and x20, x22, #0x3\n" + "cbz x21, 20f\n" + "19:" // Store to output array: Skip activation: Accumulator row 1 loop + ".inst 0xc0860444 // mova { z4.s-z7.s }, za2h.s[x12]\n" + ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n" + ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "add x12, x12, #0x4\n" + ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "cmp x12, x21, LSL #2\n" + ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604347 // st1w { z7.s, z15.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "blt 19b\n" + "20:" // Store to output array: Skip activation: Accumulator row 1 oddments + "cbz x20, 21f\n" + "subs x20, x20, #0x1\n" + ".inst 0xc0860444 // mova { z4.s-z7.s }, za2h.s[x12]\n" + ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n" + ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "beq 21f\n" + "subs x20, x20, #0x1\n" + ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "beq 21f\n" + ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "21:" // Store to output array: Skip activation: Accumulator row 1 oddments: End + "subs x25, x25, x22\n" + "beq 22f\n" + "b 29f\n" + "22:" // Store to output array: Skip activation: End + "cntw x23\n" + "ld1rw { z21.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" + "mov x12, #0x0\n" + "cmp x25, x23\n" + "ld1rw { z20.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" + "csel x22, x25, x23, LT\n" + "lsr x21, x22, #0x2\n" + "and x20, x22, #0x3\n" + "cbz x21, 24f\n" + "23:" // Store to output array: Accumulator row 0 loop + ".inst 0xc0860410 // mova { z16.s-z19.s }, za0h.s[x12]\n" + ".inst 0xc0860438 // mova { z24.s-z27.s }, za1h.s[x12]\n" + ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n" + ".inst 0xc1b4cab8 // fclamp { z24.s-z27.s }, z21.s, z20.s\n" + "add x12, x12, #0x4\n" + "cmp x12, x21, LSL #2\n" + ".inst 0xa1604350 // st1w { z16.s, z24.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604351 // st1w { z17.s, z25.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604352 // st1w { z18.s, z26.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604353 // st1w { z19.s, z27.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "blt 23b\n" + "24:" // Store to output array: Accumulator row 0 oddments + "cbz x20, 25f\n" + ".inst 0xc0860410 // mova { z16.s-z19.s }, za0h.s[x12]\n" + ".inst 0xc0860438 // mova { z24.s-z27.s }, za1h.s[x12]\n" + "subs x20, x20, #0x1\n" + ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n" + ".inst 0xc1b4cab8 // fclamp { z24.s-z27.s }, z21.s, z20.s\n" + ".inst 0xa1604350 // st1w { z16.s, z24.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "beq 25f\n" + "subs x20, x20, #0x1\n" + ".inst 0xa1604351 // st1w { z17.s, z25.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "beq 25f\n" + ".inst 0xa1604352 // st1w { z18.s, z26.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "25:" // Store to output array: Accumulator row 0 oddments: End + "subs x25, x25, x22\n" + "beq 29f\n" + "cmp x25, x23\n" + "mov x12, #0x0\n" + "csel x20, x25, x23, LT\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 27f\n" + "26:" // Store to output array: Accumulator row 1 loop + ".inst 0xc0860440 // mova { z0.s-z3.s }, za2h.s[x12]\n" + ".inst 0xc0860468 // mova { z8.s-z11.s }, za3h.s[x12]\n" + ".inst 0xc1b4caa0 // fclamp { z0.s-z3.s }, z21.s, z20.s\n" + ".inst 0xc1b4caa8 // fclamp { z8.s-z11.s }, z21.s, z20.s\n" + "add x12, x12, #0x4\n" + "cmp x12, x21, LSL #2\n" + ".inst 0xa1604340 // st1w { z0.s, z8.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604341 // st1w { z1.s, z9.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604342 // st1w { z2.s, z10.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604343 // st1w { z3.s, z11.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "blt 26b\n" + "27:" // Store to output array: Accumulator row 1 oddments + "cbz x20, 28f\n" + ".inst 0xc0860450 // mova { z16.s-z19.s }, za2h.s[x12]\n" + ".inst 0xc0860478 // mova { z24.s-z27.s }, za3h.s[x12]\n" + "subs x20, x20, #0x1\n" + ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n" + ".inst 0xc1b4cab8 // fclamp { z24.s-z27.s }, z21.s, z20.s\n" + ".inst 0xa1604350 // st1w { z16.s, z24.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "beq 28f\n" + "subs x20, x20, #0x1\n" + ".inst 0xa1604351 // st1w { z17.s, z25.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "beq 28f\n" + ".inst 0xa1604352 // st1w { z18.s, z26.s }, p8, [x26]\n" + "28:" // Store to output array: Accumulator row 1 oddments: End + "29:" // Store to output array: End + "tbz x17, #0, 31f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "30:" // Store to output array: Refill accumulators: Loop + ".inst 0xa040c608 // ld1w { z8.s-z11.s }, pn9.b/Z, [x16]\n" + ".inst 0xa041c600 // ld1w { z0.s-z3.s }, pn9.b/Z, [x16, #0x4, MUL VL]\n" + ".inst 0xa042c604 // ld1w { z4.s-z7.s }, pn9.b/Z, [x16, #0x8, MUL VL]\n" + ".inst 0xa043c60c // ld1w { z12.s-z15.s }, pn9.b/Z, [x16, #0xc, MUL VL]\n" + ".inst 0xc0840500 // mova za0h.s[x12], { z8.s-z11.s }\n" + "addvl x16, x16, #16\n" + ".inst 0xc0840401 // mova za1h.s[x12], { z0.s-z3.s }\n" + ".inst 0xc0840482 // mova za2h.s[x12], { z4.s-z7.s }\n" + ".inst 0xc0840583 // mova za3h.s[x12], { z12.s-z15.s }\n" + "add x12, x12, #0x4\n" + "cmp x12, x20\n" + "blt 30b\n" + "31:" // End block + "incw x11, ALL, MUL #2\n" + "cmp x11, x10\n" + "blt 4b\n" + "incw x13, ALL, MUL #2\n" + "mov x11, #0x0\n" + "cmp x13, x14\n" + "mov x9, x27\n" + "blt 3b\n" + ".inst 0xd503467f // SMSTOP\n" + : + : [args] "r"(&args), [offsetof_A] "I"(offsetof(KernelArgs, A)), [offsetof_B] "I"(offsetof(KernelArgs, B)), + [offsetof_C] "I"(offsetof(KernelArgs, C)), [offsetof_K] "I"(offsetof(KernelArgs, K)), + [offsetof_KernelArgs_max] "I"(offsetof(KernelArgs, max)), + [offsetof_KernelArgs_min] "I"(offsetof(KernelArgs, min)), [offsetof_M] "I"(offsetof(KernelArgs, M)), + [offsetof_N] "I"(offsetof(KernelArgs, N)), + [offsetof_accumulator_buffer] "I"(offsetof(KernelArgs, accumulator_buffer)), + [offsetof_flags] "I"(offsetof(KernelArgs, flags)), [offsetof_ldcb] "I"(offsetof(KernelArgs, ldcb)) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", + "p15", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", "x24", "x25", + "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", + "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", + "z29", "z30", "z31"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h new file mode 100644 index 0000000000000000000000000000000000000000..1dcc3404fdaf6340f9365ce6300b2114ecce18d8 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h @@ -0,0 +1,121 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_pack_f32p2vlx1_f32_sme to pack the LHS matrix. +/// -# kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme to pack the RHS matrix. + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return The m step value. +size_t kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void); + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void); + +/// Gets mr value. +/// +/// This is the packing parameter which must be used to pack the LHS matrix. +/// +/// @return The mr value. +size_t kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void); + +/// Gets nr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The nr value. +size_t kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void); + +/// Gets kr value. +/// +/// This is the packing parameter which must be used to pack the LHS and RHS matrix. +/// +/// @return The kr value. +size_t kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void); + +/// Gets sr value. +/// +/// This is the packing parameter which must be used to pack the LHS and RHS matrix. +/// +/// @return The sr value. +size_t kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void); + +/// Gets the offset in bytes to the data element in the packed LHS matrix buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t m_idx, size_t k); + +/// Gets the offset in bytes to the data element in the packed RHS matrix buffer. +/// +/// @param[in] n_idx Column index in the unpacked RHS matrix. +/// @param[in] k Number of rows in the unpacked RHS matrix. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t n_idx, size_t k); + +/// Gets the offset in bytes to the data element in the destination matrix buffer. +/// +/// @param[in] m_idx Row index. +/// @param[in] n_idx Column index. +/// @param[in] stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_stride); + +/// Gets the size in bytes of the destination matrix buffer. +/// +/// @param[in] m Number of rows. +/// @param[in] n Number of columns. +/// +/// @return The size in bytes of the destination matrix buffer. +size_t kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t m, size_t n); + +/// Runs the matrix multiplication microkernel followed by a clamp operation. +/// +/// The pointer of each buffers (packed LHS, packed RHS and output) needs to be added with offset +/// calculated using the following functions: +/// +/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa. +/// +/// @param[in] m Number of output rows to be computed. +/// @param[in] n Number of output columns to be computed. +/// @param[in] k Common dimension of the LHS and RHS operands. +/// @param[in] packed_lhs Packed LHS matrix buffer. +/// @param[in] packed_rhs Packed RHS matrix buffer. +/// @param[out] dst Output matrix buffer. +/// @param[in] dst_stride_row Row stride in bytes of the output matrix. +/// @param[in] dst_stride_col Column stride in bytes of the output matrix. +/// @param[in] clamp_min Minimum value to clamp the final result. +/// @param[in] clamp_max Maximum value to clamp the final result. +void kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, + size_t dst_stride_col, float clamp_min, float clamp_max); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c new file mode 100644 index 0000000000000000000000000000000000000000..8f0d0ecfc13357fd59e951c75cfa5245967048b1 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c @@ -0,0 +1,338 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_mr = 2; +static const size_t kai_kr = 1; +static const size_t kai_sr = 1; + +size_t kai_get_m_step_lhs_pack_f32p2vlx1_f32_sme(size_t mr) { + KAI_ASSUME(mr == kai_mr * kai_get_sme_vector_length_u32()); + KAI_UNUSED(mr); + + return kai_mr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme(size_t m_idx, size_t lhs_stride) { + KAI_ASSUME(m_idx % (kai_mr * kai_get_sme_vector_length_u32()) == 0); + + return m_idx * lhs_stride; +} + +size_t kai_get_lhs_packed_offset_lhs_pack_f32p2vlx1_f32_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { + const size_t scaled_mr = kai_mr * kai_get_sme_vector_length_u32(); + KAI_ASSUME(m_idx % scaled_mr == 0); + KAI_ASSUME(mr == scaled_mr); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + + KAI_UNUSED(mr); + KAI_UNUSED(kr); + KAI_UNUSED(sr); + + return m_idx * k * sizeof(float); +} + +size_t kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { + KAI_ASSUME(mr == kai_mr * kai_get_sme_vector_length_u32()); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + + KAI_UNUSED(mr); + KAI_UNUSED(kr); + KAI_UNUSED(sr); + + return kai_roundup(m, kai_mr * kai_get_sme_vector_length_u32()) * k * sizeof(float); +} + +void kai_run_lhs_pack_f32p2vlx1_f32_sme( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, + void* lhs_packed) { + KAI_ASSUME(mr == kai_mr * kai_get_sme_vector_length_u32()); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + KAI_ASSUME(lhs != NULL); + KAI_ASSUME(lhs_packed != NULL); + + KAI_ASSUME(m_idx_start == 0); + + const size_t block_height = kai_mr * kai_get_sme_vector_length_u32(); + const size_t width = k; + const size_t row_offset = 0; + + const void* in[block_height]; + + for (size_t block_y = 0; block_y < m; block_y += block_height) { + const size_t height = KAI_MIN(m - block_y, block_height); + void* out = lhs_packed + block_y * k * sizeof(float); + + for (size_t y = 0; y < height; y++) { + in[y] = lhs + (block_y + y) * lhs_stride; + } + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x21, %x[width]\n" + "mov x20, %x[width]\n" + "incw x21\n" + "cntw x17\n" + "sub x21, x21, #0x1\n" + "sub x16, x17, #0x1\n" + "udiv x21, x21, x17\n" // n_passes = ceildiv(width, VL) + "ands x16, x20, x16\n" + "sub x20, x21, #0x1\n" + "sub x15, x17, #0x2\n" + "mov x14, #0x0\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x17, LSL #3\n" + "cntw x9, ALL, MUL #2\n" + "ldr x28, [x11, #0x0]\n" + "cntw x27, ALL, MUL #3\n" + "lsr x20, x20, #0x1\n" // n_loops = (n_passes - 1) / 2 + "ldr x26, [x10, #0x0]\n" + "and x25, x21, #0x1\n" // odd_tail = bool(n_passes & 0x1) + "csel x16, x16, x17, NE\n" + "ldr x24, [x11, #0x8]\n" + "ptrue p12.s\n" + "whilelt p11.s, XZR, %x[height]\n" + "ldr x21, [x10, #0x8]\n" + "whilelt p10.s, x17, %x[height]\n" + "mov x23, %x[row_offset]\n" + "mov x22, %x[out]\n" + "whilelt p9.s, x14, %x[width]\n" + "whilelt p8.s, x14, %x[width]\n" + "add x11, x11, #0x10\n" + "add x10, x10, #0x10\n" + "mov x12, #0x0\n" + "cbz x15, 2f\n" + "1:" // K loop: Charge: Loop + ".inst 0x25306163 // psel p3.s, p8.s/Z, p11.s[w12]\n" + ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" + ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" + ".inst 0x25706140 // psel p0.s, p8.s/Z, p10.s[w12, #1]\n" + ".inst 0xe0970f80 // ld1w { za0h.s[x12] }, p3/Z, [x28, x23, LSL #2]\n" + "ldr x28, [x11, #0x0]\n" + ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" + "ldr x26, [x10, #0x0]\n" + ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe09702a5 // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x23, LSL #2]\n" + "add x12, x12, #0x2\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "cmp x12, x15\n" + "blt 1b\n" + "2:" // K loop: Charge: End + ".inst 0x25306163 // psel p3.s, p8.s/Z, p11.s[w12]\n" + ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" + ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" + ".inst 0x25706140 // psel p0.s, p8.s/Z, p10.s[w12, #1]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x17, LSL #3\n" + ".inst 0xe0970f80 // ld1w { za0h.s[x12] }, p3/Z, [x28, x23, LSL #2]\n" + "ldr x28, [x11, #0x0]\n" + "incw x14\n" + ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" + "ldr x26, [x10, #0x0]\n" + ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe09702a5 // ld1w { za1h.s[x12, #1] }, p0/Z, [x21, x23, LSL #2]\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "incw x23\n" + "cbz x20, 8f\n" + "mov x20, x20\n" + "3:" // K loop: Main loop + "whilelt p8.s, x14, %x[width]\n" + "mov x13, #0x0\n" + "cbz x15, 5f\n" + "4:" // K loop: Main loop: First: Loop + ".inst 0x25316160 // psel p0.s, p8.s/Z, p11.s[w13]\n" + ".inst 0x25316142 // psel p2.s, p8.s/Z, p10.s[w13]\n" + ".inst 0x25716161 // psel p1.s, p8.s/Z, p11.s[w13, #1]\n" + ".inst 0x25716143 // psel p3.s, p8.s/Z, p10.s[w13, #1]\n" + ".inst 0xe0972388 // ld1w { za2h.s[x13] }, p0/Z, [x28, x23, LSL #2]\n" + ".inst 0x25317120 // psel p0.s, p12.s/Z, p9.s[w13]\n" + "ldr x28, [x11, #0x0]\n" + ".inst 0xe0972b4c // ld1w { za3h.s[x13] }, p2/Z, [x26, x23, LSL #2]\n" + ".inst 0x25317122 // psel p2.s, p12.s/Z, p9.s[w13]\n" + "ldr x26, [x10, #0x0]\n" + ".inst 0xe0972709 // ld1w { za2h.s[x13, #1] }, p1/Z, [x24, x23, LSL #2]\n" + ".inst 0x25717121 // psel p1.s, p12.s/Z, p9.s[w13, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0972ead // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x23, LSL #2]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bfa2c0 // st1w { za0v.s[x13] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25717120 // psel p0.s, p12.s/Z, p9.s[w13, #1]\n" + ".inst 0xe0b1aac4 // st1w { za1v.s[x13] }, p2/Z, [x22, x17, LSL #2]\n" + "add x10, x10, #0x10\n" + ".inst 0xe0a9a6c1 // st1w { za0v.s[x13, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bba2c5 // st1w { za1v.s[x13, #1] }, p0/Z, [x22, x27, LSL #2]\n" + "add x13, x13, #0x2\n" + "addvl x22, x22, #4\n" + "cmp x13, x15\n" + "blt 4b\n" + "5:" // K loop: Main loop: First: Tail + ".inst 0x25316160 // psel p0.s, p8.s/Z, p11.s[w13]\n" + ".inst 0x25316142 // psel p2.s, p8.s/Z, p10.s[w13]\n" + ".inst 0x25716161 // psel p1.s, p8.s/Z, p11.s[w13, #1]\n" + ".inst 0x25716143 // psel p3.s, p8.s/Z, p10.s[w13, #1]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x17, LSL #3\n" + ".inst 0xe0972388 // ld1w { za2h.s[x13] }, p0/Z, [x28, x23, LSL #2]\n" + ".inst 0x25317120 // psel p0.s, p12.s/Z, p9.s[w13]\n" + "ldr x28, [x11, #0x0]\n" + "mov x12, #0x0\n" + ".inst 0xe0972b4c // ld1w { za3h.s[x13] }, p2/Z, [x26, x23, LSL #2]\n" + ".inst 0x25317122 // psel p2.s, p12.s/Z, p9.s[w13]\n" + "ldr x26, [x10, #0x0]\n" + ".inst 0xe0972709 // ld1w { za2h.s[x13, #1] }, p1/Z, [x24, x23, LSL #2]\n" + ".inst 0x25717121 // psel p1.s, p12.s/Z, p9.s[w13, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0972ead // ld1w { za3h.s[x13, #1] }, p3/Z, [x21, x23, LSL #2]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bfa2c0 // st1w { za0v.s[x13] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25717120 // psel p0.s, p12.s/Z, p9.s[w13, #1]\n" + ".inst 0xe0b1aac4 // st1w { za1v.s[x13] }, p2/Z, [x22, x17, LSL #2]\n" + "whilelt p9.s, x14, %x[width]\n" + "incw x14\n" + ".inst 0xe0a9a6c1 // st1w { za0v.s[x13, #1] }, p1/Z, [x22, x9, LSL #2]\n" + "add x10, x10, #0x10\n" + "incw x23\n" + ".inst 0xe0bba2c5 // st1w { za1v.s[x13, #1] }, p0/Z, [x22, x27, LSL #2]\n" + "addvl x22, x22, #4\n" + "whilelt p8.s, x14, %x[width]\n" + "cbz x15, 7f\n" + "6:" // K loop: Main loop: Second: Loop + ".inst 0x25306160 // psel p0.s, p8.s/Z, p11.s[w12]\n" + ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" + ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" + ".inst 0x25706143 // psel p3.s, p8.s/Z, p10.s[w12, #1]\n" + ".inst 0xe0970380 // ld1w { za0h.s[x12] }, p0/Z, [x28, x23, LSL #2]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + "ldr x28, [x11, #0x0]\n" + ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" + ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" + "ldr x26, [x10, #0x0]\n" + ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" + ".inst 0x25707121 // psel p1.s, p12.s/Z, p9.s[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0970ea5 // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x23, LSL #2]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25707120 // psel p0.s, p12.s/Z, p9.s[w12, #1]\n" + ".inst 0xe0b18acc // st1w { za3v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" + "add x10, x10, #0x10\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bb82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x27, LSL #2]\n" + "add x12, x12, #0x2\n" + "addvl x22, x22, #4\n" + "cmp x12, x15\n" + "blt 6b\n" + "7:" // K loop: Main loop: Second: Tail + ".inst 0x25306160 // psel p0.s, p8.s/Z, p11.s[w12]\n" + ".inst 0x25306142 // psel p2.s, p8.s/Z, p10.s[w12]\n" + ".inst 0x25706161 // psel p1.s, p8.s/Z, p11.s[w12, #1]\n" + ".inst 0x25706143 // psel p3.s, p8.s/Z, p10.s[w12, #1]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x17, LSL #3\n" + ".inst 0xe0970380 // ld1w { za0h.s[x12] }, p0/Z, [x28, x23, LSL #2]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + "ldr x28, [x11, #0x0]\n" + ".inst 0xe0970b44 // ld1w { za1h.s[x12] }, p2/Z, [x26, x23, LSL #2]\n" + ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" + "ldr x26, [x10, #0x0]\n" + ".inst 0xe0970701 // ld1w { za0h.s[x12, #1] }, p1/Z, [x24, x23, LSL #2]\n" + ".inst 0x25707121 // psel p1.s, p12.s/Z, p9.s[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe0970ea5 // ld1w { za1h.s[x12, #1] }, p3/Z, [x21, x23, LSL #2]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0xe0bf82c8 // st1w { za2v.s[x12] }, p0/Z, [x22, XZR, LSL #2]\n" + ".inst 0x25707120 // psel p0.s, p12.s/Z, p9.s[w12, #1]\n" + ".inst 0xe0b18acc // st1w { za3v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" + "whilelt p9.s, x14, %x[width]\n" + "subs x20, x20, #0x1\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" + "add x10, x10, #0x10\n" + "incw x14\n" + ".inst 0xe0bb82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x27, LSL #2]\n" + "addvl x22, x22, #4\n" + "incw x23\n" + "bgt 3b\n" + "8:" // K loop: Tails + "cbnz x25, 11f\n" + "mov x11, %x[in]\n" + "whilelt p8.s, x14, %x[width]\n" + "mov x12, #0x0\n" + "9:" // K loop: Tails: Even: First + ".inst 0x25307123 // psel p3.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307122 // psel p2.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25306161 // psel p1.s, p8.s/Z, p11.s[w12]\n" + ".inst 0x25306140 // psel p0.s, p8.s/Z, p10.s[w12]\n" + ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0b18ac4 // st1w { za1v.s[x12] }, p2/Z, [x22, x17, LSL #2]\n" + "addvl x22, x22, #2\n" + "ldr x21, [x11, #0x0]\n" + "ldr x20, [x11, x17, LSL #0x3]\n" + "add x11, x11, #0x8\n" + ".inst 0xe09706a8 // ld1w { za2h.s[x12] }, p1/Z, [x21, x23, LSL #2]\n" + ".inst 0xe097028c // ld1w { za3h.s[x12] }, p0/Z, [x20, x23, LSL #2]\n" + "add x12, x12, #0x1\n" + "cmp x12, x17\n" + "blt 9b\n" + "whilelt p9.s, x14, %x[width]\n" + "whilelt p8.s, x14, %x[width]\n" + "mov x12, #0x0\n" + "10:" // K loop: Tails: Even: Second + ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0b182cc // st1w { za3v.s[x12] }, p0/Z, [x22, x17, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x16\n" + "blt 10b\n" + "whilelt p8.s, x14, %x[width]\n" + "b 13f\n" + "11:" // K loop: Tails: Odd + "mov x12, #0x0\n" + "12:" // K loop: Tails: Odd: Loop + ".inst 0x25307121 // psel p1.s, p12.s/Z, p9.s[w12]\n" + ".inst 0x25307120 // psel p0.s, p12.s/Z, p9.s[w12]\n" + ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0b182c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x17, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x16\n" + "blt 12b\n" + "13:" // K loop: End + "mov %x[out], x22\n" + ".inst 0xd503467f // SMSTOP\n" + : [out] "+&r"(out) + : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", + "p14", "p15", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", + "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", + "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", + "z26", "z27", "z28", "z29", "z30", "z31"); + } +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h new file mode 100644 index 0000000000000000000000000000000000000000..82c5db48f846b5045b17c3440447002fa0796ded --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h @@ -0,0 +1,77 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @param[in] mr Number of rows to be interleaved. +/// +/// @return The m step value. +size_t kai_get_m_step_lhs_pack_f32p2vlx1_f32_sme(size_t mr); + +/// Gets the offset in bytes to the data element in the LHS buffer. +/// +/// @param[in] m_idx Row index. +/// @param[in] lhs_stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme(size_t m_idx, size_t lhs_stride); + +/// Gets the offset in bytes to the data element in the packed LHS buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// @param[in] mr Number of rows to be interleaved. +/// @param[in] kr Unused. Must be 1. +/// @param[in] sr Unused. Must be 1. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_lhs_pack_f32p2vlx1_f32_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr); + +/// Gets the size in bytes of the packed LHS buffer. +/// +/// @param[in] m Number of rows in the unpacked LHS matrix. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// @param[in] mr Number of rows to be interleaved. +/// @param[in] kr Unused. Must be 1. +/// @param[in] sr Unused. Must be 1. +/// +/// @return The size in bytes of the packed LHS buffer. +size_t kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr); + +/// Runs the LHS packing function for matrix multiplication. +/// +/// The pointer of each buffers (LHS and packed LHS) needs to be added with offset +/// calculated using the following functions: +/// +/// * LHS: @ref kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme. +/// * Packed LHS: @ref kai_get_lhs_packed_offset_lhs_pack_f32p2vlx1_f32_sme. +/// +/// @param[in] m Number of rows of the unpacked LHS matrix. +/// @param[in] k Common dimension between the LHS and RHS matrix. +/// @param[in] mr Block size in M dimension. It must be 2 * kai_get_sme_vector_length_u32(). +/// @param[in] kr Block size in K dimension. It must be 1. +/// @param[in] sr Number of kr splits. It must be 1. +/// @param[in] m_idx_start Unused. Must be 0. +/// @param[in] lhs LHS matrix data buffer. +/// @param[in] lhs_stride Row stride in bytes of the LHS matrix. +/// @param[out] lhs_packed Packed RHS matrix. +void kai_run_lhs_pack_f32p2vlx1_f32_sme( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, + void* lhs_packed); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c new file mode 100644 index 0000000000000000000000000000000000000000..75a48fb4a704f250fd368d0e44d861933270e6c9 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c @@ -0,0 +1,163 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_nr = 2; +static const size_t kai_kr = 1; + +size_t kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(void) { + return kai_nr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx) { + KAI_ASSUME(n_idx % (kai_nr * kai_get_sme_vector_length_u32()) == 0); + + return n_idx * sizeof(uint32_t); +} + +size_t kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx) { + return n_idx * sizeof(uint32_t); +} + +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % (kai_nr * kai_get_sme_vector_length_u32()) == 0); + + return n_idx * (sizeof(uint32_t) + k * sizeof(uint32_t)); +} + +size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n, size_t k) { + return kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( + kai_roundup(n, kai_nr * kai_get_sme_vector_length_u32()), k); +} + +void kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { + KAI_ASSUME(num_groups == 1); + KAI_ASSUME(nr == kai_nr * kai_get_sme_vector_length_u32()); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == 1); + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(bias != NULL); + KAI_ASSUME(scale == NULL); + KAI_ASSUME(rhs_packed != NULL); + KAI_ASSUME(extra_bytes == 0); + KAI_ASSUME(params == NULL); + + size_t height = k; + const size_t width = n; + const void* in = rhs; + void* out = rhs_packed; + const size_t in_stride = rhs_stride; + size_t out_stride = kai_nr * kai_get_sme_vector_length_u8() * (height + sizeof(uint32_t) / sizeof(uint32_t)); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x22, %x[out]\n" + "mov x21, %x[width]\n" + "ptrue p2.b\n" + "1:" // Bias: Full loop + "mov x20, x21\n" + "decw x21, ALL, MUL #2\n" + "whilelt p1.s, XZR, x20\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "cmp x21, #0x0\n" + "ld1w { z17.s }, p1/Z, [%x[bias]]\n" + "ld1w { z16.s }, p0/Z, [%x[bias], #1, MUL VL]\n" + "incb %x[bias], ALL, MUL #2\n" + "st1w { z17.s }, p2, [x22]\n" + "st1w { z16.s }, p2, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 1b\n" + "cmp %x[height], #0x4\n" + "incb %x[out], ALL, MUL #2\n" + "blt 5f\n" + "2:" // Main row loop: Head + "mov x26, %x[in]\n" + "mov x25, %x[out]\n" + "add x24, x26, %x[in_stride]\n" + "sub %x[height], %x[height], #0x4\n" + "add x23, x24, %x[in_stride]\n" + "mov x22, %x[width]\n" + "add x21, x23, %x[in_stride]\n" + "add %x[in], x21, %x[in_stride]\n" + "3:" // Main row loop: Column loop + "mov x20, x22\n" + "decw x22, ALL, MUL #2\n" + "whilelt p1.s, XZR, x20\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "cmp x22, #0x0\n" + "ld1w { z23.s }, p1/Z, [x26]\n" + "ld1w { z22.s }, p0/Z, [x26, #1, MUL VL]\n" + "addvl x26, x26, #2\n" + "ld1w { z21.s }, p1/Z, [x24]\n" + "ld1w { z20.s }, p0/Z, [x24, #1, MUL VL]\n" + "addvl x24, x24, #2\n" + "ld1w { z19.s }, p1/Z, [x23]\n" + "ld1w { z18.s }, p0/Z, [x23, #1, MUL VL]\n" + "addvl x23, x23, #2\n" + "ld1w { z17.s }, p1/Z, [x21]\n" + "ld1w { z16.s }, p0/Z, [x21, #1, MUL VL]\n" + "addvl x21, x21, #2\n" + "st1w { z23.s }, p2, [x25]\n" + "st1w { z22.s }, p2, [x25, #1, MUL VL]\n" + "st1w { z21.s }, p2, [x25, #2, MUL VL]\n" + "st1w { z20.s }, p2, [x25, #3, MUL VL]\n" + "st1w { z19.s }, p2, [x25, #4, MUL VL]\n" + "st1w { z18.s }, p2, [x25, #5, MUL VL]\n" + "st1w { z17.s }, p2, [x25, #6, MUL VL]\n" + "st1w { z16.s }, p2, [x25, #7, MUL VL]\n" + "add x25, x25, %x[out_stride]\n" + "bgt 3b\n" + "cmp %x[height], #0x4\n" + "addvl %x[out], %x[out], #8\n" + "bge 2b\n" + "cbz %x[height], 9f\n" + "5:" // Main loop skip + "6:" // Tail row loop: Head + "mov x26, %x[in]\n" + "mov x25, %x[out]\n" + "add %x[in], x26, %x[in_stride]\n" + "sub %x[height], %x[height], #0x1\n" + "mov x21, %x[width]\n" + "7:" // Tail row loop: Column loop + "mov x20, x21\n" + "decw x21, ALL, MUL #2\n" + "whilelt p1.s, XZR, x20\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "cmp x21, #0x0\n" + "ld1w { z17.s }, p1/Z, [x26]\n" + "ld1w { z16.s }, p0/Z, [x26, #1, MUL VL]\n" + "addvl x26, x26, #2\n" + "st1w { z17.s }, p2, [x25]\n" + "st1w { z16.s }, p2, [x25, #1, MUL VL]\n" + "add x25, x25, %x[out_stride]\n" + "bgt 7b\n" + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #2\n" + "bge 6b\n" + "9:" // Done + ".inst 0xd503467f // SMSTOP\n" + : [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out) + : [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", + "p15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", + "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", + "z25", "z26", "z27", "z28", "z29", "z30", "z31"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h new file mode 100644 index 0000000000000000000000000000000000000000..85d6dd406a15f7707c927d91ea4b8f7300db7228 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h @@ -0,0 +1,80 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets n step value. +/// +/// The starting row index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(void); + +/// Gets the offset in bytes to the data element in the RHS matrix buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the bias buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the packed RHS buffer. +/// +/// @param[in] n_idx Row index. +/// @param[in] k Number of columns. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx, size_t k); + +/// Gets the size in bytes of the packed RHS buffer. +/// +/// @param[in] n Number of rows. +/// @param[in] k Number of columns. +/// +/// @return The size in bytes of the packed RHS buffer. +size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n, size_t k); + +/// Runs the RHS packing function for matrix multiplication. +/// +/// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset +/// calculated using the following functions: +/// +/// * RHS: @ref kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme. +/// * Bias: @ref kai_get_packed_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme. +/// * Output: @ref kai_get_dst_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme. +/// +/// @param[in] num_groups Number of groups. It must be 1. +/// @param[in] n Number of columns of the output matrix. +/// @param[in] k Common dimension between the LHS and RHS matrix. +/// @param[in] nr Block size in N dimension. It must be 2 * kai_get_sme_vector_length_u32(). +/// @param[in] kr Block size in K dimension. It must be 1. +/// @param[in] sr Number of kr splits. It must be 1. +/// @param[in] rhs_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs RHS matrix data buffer. +/// @param[in] bias Bias matrix data buffer. +/// @param[in] scale Scale data buffer. It must be NULL. +/// @param[out] rhs_packed Packed RHS matrix. +/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. +/// @param[in] params Extra packing parameters. It must be NULL. +void kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/test/common/compare.cpp b/test/common/compare.cpp index 91b6bc6656c6a6a95bbb8e465ab22760e4e3884f..54af776fd6937a5dd4deaa3469977a7ec1db3a83 100644 --- a/test/common/compare.cpp +++ b/test/common/compare.cpp @@ -47,23 +47,45 @@ std::tuple calculate_error(T imp, T ref) { /// Compares matrices with per-row quantization. template bool compare_raw( - const void* imp_data, const void* ref_data, size_t full_height, size_t full_width, const Rect& rect, - MismatchHandler& handler) { - for (size_t y = 0; y < full_height; ++y) { - for (size_t x = 0; x < full_width; ++x) { - const auto in_roi = - y >= rect.start_row() && y < rect.end_row() && x >= rect.start_col() && x < rect.end_col(); + const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, + const Rect& rect, MismatchHandler& handler) { + const auto block_height = format.actual_block_height(full_height); + const auto block_width = format.actual_block_width(full_width); + const auto subblock_height = format.actual_subblock_height(full_height); + const auto subblock_width = format.actual_subblock_width(full_width); - const auto imp_value = read_array(imp_data, y * full_width + x); - const auto ref_value = in_roi ? read_array(ref_data, y * full_width + x) : static_cast(0); + size_t idx = 0; - const auto [abs_err, rel_err] = calculate_error(imp_value, ref_value); + for (size_t y_block = 0; y_block < full_height; y_block += block_height) { + for (size_t x_block = 0; x_block < full_width; x_block += block_width) { + for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) { + for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) { + for (size_t y_element = 0; y_element < subblock_height; ++y_element) { + for (size_t x_element = 0; x_element < subblock_width; ++x_element) { + const auto y = y_block + y_subblock + y_element; + const auto x = x_block + x_subblock + x_element; - if (abs_err != 0 || rel_err != 0) { - const auto notifying = !in_roi || handler.handle_data(abs_err, rel_err); + const auto in_roi = y >= rect.start_row() && y < rect.end_row() && x >= rect.start_col() && + x < rect.end_col(); + + const auto imp_value = read_array(imp_data, idx); + const auto ref_value = in_roi ? read_array(ref_data, idx) : static_cast(0); - if (notifying) { - KAI_LOGE("Mismatched data at (", y, ", ", x, "): actual = ", imp_value, ", expected: ", ref_value); + const auto [abs_err, rel_err] = calculate_error(imp_value, ref_value); + + if (abs_err != 0 || rel_err != 0) { + const auto notifying = !in_roi || handler.handle_data(abs_err, rel_err); + + if (notifying) { + KAI_LOGE( + "Mismatched data at (", y, ", ", x, "): actual = ", imp_value, + ", expected: ", ref_value); + } + } + + ++idx; + } + } } } } @@ -186,10 +208,10 @@ bool compare( case DataFormat::PackFormat::NONE: switch (data_type) { case DataType::FP32: - return compare_raw(imp_data, ref_data, full_height, full_width, rect, handler); + return compare_raw(imp_data, ref_data, format, full_height, full_width, rect, handler); case DataType::FP16: - return compare_raw(imp_data, ref_data, full_height, full_width, rect, handler); + return compare_raw(imp_data, ref_data, format, full_height, full_width, rect, handler); default: break; @@ -201,6 +223,9 @@ bool compare( if (data_type == DataType::FP16 && offset_dt == DataType::FP16) { return compare_per_row( imp_data, ref_data, format, full_height, full_width, rect, handler); + } else if (data_type == DataType::FP32 && offset_dt == DataType::FP32) { + return compare_per_row( + imp_data, ref_data, format, full_height, full_width, rect, handler); } else if (data_type == DataType::BF16 && offset_dt == DataType::FP32) { return compare_per_row( imp_data, ref_data, format, full_height, full_width, rect, handler); diff --git a/test/common/cpu_info.cpp b/test/common/cpu_info.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6fa59677fd6e02d07ca1aa7d2f6657bfa3b61a72 --- /dev/null +++ b/test/common/cpu_info.cpp @@ -0,0 +1,85 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/cpu_info.hpp" + +#include +#include +#include + +#include "kai/kai_common.h" + +#if defined(__aarch64__) && defined(__linux__) +#include +#endif // defined(__aarch64__) && defined(__linux__) + +#if defined(__aarch64__) && defined(__APPLE__) +#include +#include +#endif // defined(__aarch64__) && defined(__APPLE__) + +namespace kai::test { + +namespace { + +#if defined(__aarch64__) && defined(__linux__) +constexpr uint64_t A64_HWCAP2_SME = 1UL << 23; +constexpr uint64_t A64_HWCAP2_SME2 = 1UL << 37; +#endif // defined(__aarch64__) && defined(__linux__) + +#if defined(__aarch64__) && defined(__APPLE__) +template +T get_sysctl_by_name(std::string_view name) { + T value{}; + size_t size = sizeof(T); + + KAI_ASSERT(sysctlbyname(name.data(), nullptr, &size, nullptr, 0) == 0); + KAI_ASSERT(size == sizeof(T)); + + [[maybe_unused]] int status = sysctlbyname(name.data(), &value, &size, nullptr, 0); + KAI_ASSERT(status == 0); + + return value; +} +#endif // defined(__aarch64__) && defined(__APPLE__) + +/// Information about the CPU that is executing the program. +struct CpuInfo { + CpuInfo() { +#if defined(__aarch64__) && defined(__linux__) + const uint64_t hwcaps2 = getauxval(AT_HWCAP2); + + has_sme = (hwcaps2 & A64_HWCAP2_SME) != 0; + has_sme2 = (hwcaps2 & A64_HWCAP2_SME2) != 0; +#endif // defined(__aarch64__) && defined(__linux__) + +#if defined(__aarch64__) && defined(__APPLE__) + has_sme = get_sysctl_by_name("hw.optional.arm.FEAT_SME") == 1; + has_sme2 = get_sysctl_by_name("hw.optional.arm.FEAT_SME2") == 1; +#endif // defined(__aarch64__) && defined(__APPLE__) + } + + /// Gets the singleton @ref CpuInfo object. + static CpuInfo& current() { + static CpuInfo cpu_info{}; + return cpu_info; + } + + bool has_sme{}; ///< FEAT_SME is supported. + bool has_sme2{}; ///< FEAT_SME2 is supported. +}; + +} // namespace + +bool cpu_has_sme() { + return CpuInfo::current().has_sme; +} + +bool cpu_has_sme2() { + return CpuInfo::current().has_sme2; +} + +} // namespace kai::test diff --git a/test/common/cpu_info.hpp b/test/common/cpu_info.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b231ecea2d651f20b6b340ed5fa3a1dadaef1b42 --- /dev/null +++ b/test/common/cpu_info.hpp @@ -0,0 +1,17 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +namespace kai::test { + +/// Returns a value indicating whether the current CPU supports FEAT_SME. +bool cpu_has_sme(); + +/// Returns a value indicating whether the current CPU supports FEAT_SME2. +bool cpu_has_sme2(); + +} // namespace kai::test diff --git a/test/common/data_format.cpp b/test/common/data_format.cpp index a16269e9dea6646a5f95852426ea672983dff9ab..296994f5c27e52ae3c7a8b9717c9a952b1e1b48a 100644 --- a/test/common/data_format.cpp +++ b/test/common/data_format.cpp @@ -129,7 +129,7 @@ uintptr_t DataFormat::default_row_stride(size_t width) const { switch (_pack_format) { case PackFormat::NONE: - return padded_width * data_type_size_in_bits(_data_type) / 8; + return (_block_height > 0 ? _block_height : 1) * padded_width * data_type_size_in_bits(_data_type) / 8; case PackFormat::BIAS_PER_ROW: KAI_ASSUME(_block_height > 0); diff --git a/test/common/printer.cpp b/test/common/printer.cpp index 7d7144d836bcad603a3a6fdf8c1b37cc6c8f1a6a..9fc3e0d8fb096f414b2df8d87e36efba6c23e940 100644 --- a/test/common/printer.cpp +++ b/test/common/printer.cpp @@ -63,14 +63,60 @@ inline void print_data(std::ostream& os, const uint8_t* data, size_t len, DataTy } } -void print_matrix_raw(std::ostream& os, const uint8_t* data, DataType data_type, size_t height, size_t width) { - const auto row_stride = width * data_type_size_in_bits(data_type) / 8; +void print_matrix_raw(std::ostream& os, const uint8_t* data, const DataFormat& format, size_t height, size_t width) { + const auto data_type = format.data_type(); + const auto esize_bits = data_type_size_in_bits(data_type); + const auto block_height = format.actual_block_height(height); + const auto block_width = format.actual_block_width(width); + const auto subblock_height = format.actual_subblock_height(height); + const auto subblock_width = format.actual_subblock_width(width); os << "[\n"; - for (size_t y = 0; y < height; ++y) { - os << " ["; - print_data(os, data + y * row_stride, width, data_type); - os << "],\n"; + for (size_t y_block = 0; y_block < height; y_block += block_height) { + if (block_height != height) { + os << " [\n"; + } + + for (size_t x_block = 0; x_block < width; x_block += block_width) { + if (block_width != width) { + os << " [\n"; + } + + for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) { + if (subblock_height != block_height) { + os << " [\n"; + } + + for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) { + if (subblock_width != block_width) { + os << " [\n"; + } + + for (size_t y = 0; y < subblock_height; ++y) { + os << " ["; + print_data(os, data, subblock_width, data_type); + data += subblock_width * esize_bits / 8; + os << "],\n"; + } + + if (subblock_width != block_width) { + os << " ]\n"; + } + } + + if (subblock_height != block_height) { + os << " ]\n"; + } + } + + if (block_width != width) { + os << " ],\n"; + } + } + + if (block_height != height) { + os << " ],\n"; + } } os << "]\n"; } @@ -84,7 +130,7 @@ void print_matrix_per_row( const auto num_blocks = (height + block_height - 1) / block_height; KAI_ASSUME(format.default_size_in_bytes(height, width) % num_blocks == 0); - const auto block_data_bytes = format.default_size_in_bytes(height, width) / num_blocks; + const auto block_data_bytes = block_height * width * data_type_size_in_bits(format.data_type()) / 8; const auto block_offsets_bytes = block_height * data_type_size_in_bits(format.zero_point_data_type()) / 8; const auto block_scales_bytes = has_scale ? block_height * data_type_size_in_bits(format.scale_data_type()) / 8 : 0; @@ -115,7 +161,7 @@ void print_matrix( switch (format.pack_format()) { case DataFormat::PackFormat::NONE: - print_matrix_raw(os, reinterpret_cast(data), format.data_type(), height, width); + print_matrix_raw(os, reinterpret_cast(data), format, height, width); break; case DataFormat::PackFormat::BIAS_PER_ROW: diff --git a/test/common/sme.cpp b/test/common/sme.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b13e991014567bd5298c34ecf5b790d7a9df26d1 --- /dev/null +++ b/test/common/sme.cpp @@ -0,0 +1,88 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE) +#error This file must be compiled for AArch64, FEAT_SVE. +#else // Architectural features check. + +#include "test/common/sme.hpp" + +#include "test/common/cpu_info.hpp" + +namespace kai::test { + +template <> +uint64_t get_sme_vector_length<1>() { + static uint64_t res = 0; + + if (res == 0) { + if (cpu_has_sme()) { + __asm __volatile( + ".inst 0xd503477f // SMSTART ZA\n" + "cntb %0\n" + ".inst 0xd503467f // SMSTOP\n" + : "=r"(res) + : + : "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", + "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", + "z30", "z31"); + } else { + res = 1; + } + } + + return res; +} + +template <> +uint64_t get_sme_vector_length<2>() { + static uint64_t res = 0; + + if (res == 0) { + if (cpu_has_sme()) { + __asm __volatile( + ".inst 0xd503477f // SMSTART ZA\n" + "cnth %0\n" + ".inst 0xd503467f // SMSTOP\n" + : "=r"(res) + : + : "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", + "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", + "z30", "z31"); + } else { + res = 1; + } + } + + return res; +} + +template <> +uint64_t get_sme_vector_length<4>() { + static uint64_t res = 0; + + if (res == 0) { + if (cpu_has_sme()) { + __asm __volatile( + ".inst 0xd503477f // SMSTART ZA\n" + "cntw %0\n" + ".inst 0xd503467f // SMSTOP\n" + : "=r"(res) + : + : "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", + "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", + "z30", "z31"); + } else { + res = 1; + } + } + + return res; +} + +} // namespace kai::test + +#endif // Architectural features check. diff --git a/test/common/sme.hpp b/test/common/sme.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0ea65fab74abdc6c00f305e28c0d7a482b9679e8 --- /dev/null +++ b/test/common/sme.hpp @@ -0,0 +1,24 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace kai::test { + +/// Gets the SME vector length. +template +uint64_t get_sme_vector_length(); + +/// Gets the SME vector length. +template +uint64_t get_sme_vector_length() { + return get_sme_vector_length(); +} + +} // namespace kai::test diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp index 261d7a53cdaf0e614285263ec3b97c0283521e8f..fb0f81b19b563dbb26c700aef274b4d6568ba4cc 100644 --- a/test/reference/pack.cpp +++ b/test/reference/pack.cpp @@ -24,6 +24,46 @@ namespace kai::test { namespace { +std::vector pack_block( + const void* src, size_t data_esize, size_t full_height, size_t full_width, size_t block_height, size_t block_width, + size_t subblock_height, size_t subblock_width) { + const auto dst_bytes = + round_up_multiple(full_height, block_height) * round_up_multiple(full_width, block_width) * data_esize; + + std::vector dst; + dst.resize(dst_bytes); + + const auto* src_ptr = reinterpret_cast(src); + auto* dst_ptr = dst.data(); + + for (size_t y_block = 0; y_block < full_height; y_block += block_height) { + for (size_t x_block = 0; x_block < full_width; x_block += block_width) { + for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) { + for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) { + for (size_t y_element = 0; y_element < subblock_height; ++y_element) { + if (y_block + y_subblock + y_element < full_height) { + const auto len = std::min(subblock_width, full_width - x_block - x_subblock); + + memcpy( + dst_ptr, + src_ptr + + ((y_block + y_subblock + y_element) * full_width + x_block + x_subblock) * + data_esize, + len * data_esize); + } + + dst_ptr += subblock_width * data_esize; + } + } + } + } + } + + KAI_ASSERT(reinterpret_cast(dst_ptr) - reinterpret_cast(dst.data()) == dst_bytes); + + return dst; +} + /// Packs the matrix from raw to per-row bias format. std::vector pack_bias_per_row( size_t data_esize, size_t zero_point_esize, const void* src, const void* bias, size_t height, size_t width, @@ -269,6 +309,17 @@ std::vector pack( } } + if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::NONE) { + KAI_ASSUME(src_dt == dst_dt); + + const auto data_esize = data_type_size_in_bits(dst_dt); + + if (data_esize % 8 == 0) { + return pack_block( + src, data_esize / 8, height, width, block_height, block_width, subblock_height, subblock_width); + } + } + KAI_ERROR("Unsupported operation!"); } diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index e6b6b7cf4de0c54a3baf6fb396a02f24c4cfd69a..9f2e67f13fce9a6e60ca4bac968ed14448b45a4a 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -21,17 +21,25 @@ #include #include "kai/kai_common.h" -#include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h" -#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h" #include "test/common/compare.hpp" #include "test/common/data_format.hpp" #include "test/common/data_type.hpp" #include "test/common/float16.hpp" #include "test/common/matrix_portion.hpp" #include "test/common/printer.hpp" +#include "test/common/sme.hpp" #include "test/reference/fill.hpp" #include "test/reference/pack.hpp" +// matmul_nt_nt_fp16_fp16_fp16_6x16_neon_mla +#include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h" + +// matmul_nt_nt_fp32_fp32_fp32_2vlx2vl_sme2_mopa +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" +#include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" + namespace kai::test { // NOLINTBEGIN(misc-non-private-member-variables-in-classes) @@ -112,11 +120,14 @@ struct MatMulMethod { /// Gets the size in bytes of the packed LHS matrix. /// - /// @param[in] m Size of the matrix in M dimension. - /// @param[in] k Size of the matrix in K dimension. + /// @param[in] m Number of rows in the unpacked LHS matrix. + /// @param[in] k Number of columns in the unpacked LHS matrix. + /// @param[in] mr Number of rows to be interleaved. + /// @param[in] kr Unused. Must be 1. + /// @param[in] sr Unused. Must be 1. /// /// @return The size in bytes. - std::function fn_get_packed_lhs_size; + std::function fn_get_packed_lhs_size; /// Gets the offset in bytes of the packed LHS matrix. /// @@ -128,12 +139,19 @@ struct MatMulMethod { /// Preprocesses the LHS matrix. /// - /// @param[in] m Size of the matrix in M dimension. - /// @param[in] k Size of the matrix in K dimension. + /// @param[in] m Number of rows of the unpacked LHS matrix. + /// @param[in] k Common dimension between the LHS and RHS matrix. + /// @param[in] mr Block size in M dimension. It must be {{ kernel.interleave_by }}VL. + /// @param[in] kr Block size in K dimension. It must be {{ kernel.block_by }}. + /// @param[in] sr Number of kr splits. It must be 1. + /// @param[in] m_idx_start Unused. Must be 0. /// @param[in] lhs LHS matrix data buffer. - /// @param[in] lhs_row_stride Row stride in bytes of the LHS matrix. - /// @param[out] packed_lhs Packed LHS matrix data buffer. - std::function fn_pack_lhs; + /// @param[in] lhs_stride Row stride in bytes of the LHS matrix. + /// @param[out] lhs_packed Packed RHS matrix. + std::function + fn_pack_lhs; /// Gets a value indicating whether LHS packing is needed. [[nodiscard]] bool is_pack_lhs_needed() const { @@ -220,6 +238,23 @@ struct MatMulMethod { Float16 clamp_min, Float16 clamp_max)> fn_main_hybrid_fp16; + /// Runs the matrix multiplication microkernel followed by a clamp operation. + /// + /// @param[in] m Number of output rows to be computed. + /// @param[in] n Number of output columns to be computed. + /// @param[in] k Common dimension of the LHS and RHS operands. + /// @param[in] packed_lhs Packed LHS matrix buffer. + /// @param[in] packed_rhs Packed RHS matrix buffer. + /// @param[out] dst Output matrix buffer. + /// @param[in] dst_stride_row Row stride in bytes of the output matrix. + /// @param[in] dst_stride_col Column stride in bytes of the output matrix. + /// @param[in] clamp_min Minimum value to clamp the final result. + /// @param[in] clamp_max Maximum value to clamp the final result. + std::function + fn_main_interleave_fp32; + /// Gets a value indicating whether pre-processing the RHS matrix is needed. [[nodiscard]] bool is_pack_rhs_needed() const { return fn_pack_rhs != nullptr; @@ -255,7 +290,7 @@ struct MatMulMethod { } [[nodiscard]] bool has_main_kernel() const { - return fn_main_hybrid_fp16 != nullptr; + return fn_main_hybrid_fp16 != nullptr || fn_main_interleave_fp32 != nullptr; } void main_kernel( @@ -268,6 +303,8 @@ struct MatMulMethod { fn_main_hybrid_fp16( m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(Float16), static_cast(clamp_min), static_cast(clamp_max)); + } else if (fn_main_interleave_fp32) { + fn_main_interleave_fp32(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); } else { KAI_ERROR("Main kernel is not available!"); } @@ -321,6 +358,56 @@ static const std::array matmul_methods = { .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, .fn_main_hybrid_fp16 = kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, + .fn_main_interleave_fp32 = nullptr, + }, + + MatMulMethod{ + .name = "matmul_nt_nt_fp32_fp32_fp32_2vlx2vl_sme2_mopa", + + .m0 = 2 * get_sme_vector_length(), + .n0 = 2 * get_sme_vector_length(), + + .lhs_transposed = false, + .rhs_transposed = false, + + .dst_format = DataFormat(DataType::FP32), + .lhs_format = DataFormat(DataType::FP32), + .packed_lhs_format = DataFormat(DataType::FP32, 2 * get_sme_vector_length(), 1), + .rhs_format = DataFormat(DataType::FP32), + .packed_rhs_format = DataFormat( + DataType::FP32, 2 * get_sme_vector_length(), 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, + DataType::UNKNOWN, 2 * get_sme_vector_length(), 1), + .bias_format = DataFormat(DataType::FP32), + + .fn_get_mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + .fn_get_nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + .fn_get_kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + .fn_get_sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + + .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme, + .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme, + .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + .fn_pack_lhs = kai_run_lhs_pack_f32p2vlx1_f32_sme, + + .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, + .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, + .fn_get_pack_rhs_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + .fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + .fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, + + .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, + + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + + .fn_main_hybrid_fp16 = nullptr, + .fn_main_interleave_fp32 = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, }, }; @@ -391,7 +478,8 @@ protected: std::vector ref_packed_lhs; if (has_lhs_pack) { - pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w); + ref_packed_lhs = + pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w); } const auto rhs_h = method.rhs_transposed ? info.n : info.k; @@ -467,7 +555,7 @@ TEST_P(MatMulTest, PackedLhs) { const auto rect = portion.compute_portion( lhs_h, lhs_w, method.packed_lhs_format.scheduler_block_height(lhs_h), - method.packed_lhs_format.scheduler_block_width(lhs_w)); + lhs_w); // LHS packing micro-kernel API doesn't support scheduling over K dimension. if (rect.height() == 0 || rect.width() == 0) { GTEST_SKIP(); @@ -475,12 +563,12 @@ TEST_P(MatMulTest, PackedLhs) { const auto ref_lhs_row_stride = method.lhs_format.default_row_stride(lhs_w); - const auto packed_lhs_size = method.fn_get_packed_lhs_size(info.m, info.k); + const auto packed_lhs_size = method.fn_get_packed_lhs_size(info.m, info.k, method.m0, 1, 1); const auto ref_packed_lhs_size = method.packed_lhs_format.default_size_in_bytes(lhs_h, lhs_w); ASSERT_EQ(packed_lhs_size, ref_packed_lhs_size); const auto lhs_offset = method.fn_get_lhs_offset(rect.start_row(), ref_lhs_row_stride); - const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(rect.start_row(), 0, lhs_w); + const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), lhs_w); ASSERT_EQ(lhs_offset, ref_lhs_offset); const auto packed_lhs_offset = method.fn_get_packed_lhs_offset(rect.start_row(), info.k); @@ -490,7 +578,7 @@ TEST_P(MatMulTest, PackedLhs) { std::vector packed_lhs; packed_lhs.resize(packed_lhs_size); method.fn_pack_lhs( - rect.height(), rect.width(), data.lhs.data() + lhs_offset, ref_lhs_row_stride, + rect.height(), rect.width(), method.m0, 1, 1, 0, data.lhs.data() + lhs_offset, ref_lhs_row_stride, packed_lhs.data() + packed_lhs_offset); DefaultMismatchHandler handler(0, 0.0001, 0, 0.001);