From 89940cfa46aefb7d135e2ae08bad38ebf5e2057e Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Tue, 24 Sep 2024 12:24:12 +0100 Subject: [PATCH 01/10] Add bf16 interleaved gemm kernels Signed-off-by: Gunes Bayir --- CMakeLists.txt | 8 + kai/kai_common.h | 2 +- kai/ukernels/matmul/BUILD.bazel | 31 + ...16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.c | 596 ++++++++++++++++++ ...16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.h | 122 ++++ .../matmul_clamp_bf16_bf16_f32p_interface.h | 57 ++ .../pack/kai_lhs_pack_8x4_f32_bf16_neon.c | 222 +++++++ .../pack/kai_lhs_pack_8x4_f32_bf16_neon.h | 31 + ...s_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.c | 474 ++++++++++++++ ...s_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.h | 93 +++ 10 files changed, 1635 insertions(+), 1 deletion(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.c create mode 100644 kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.h create mode 100644 kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/matmul_clamp_bf16_bf16_f32p_interface.h create mode 100644 kai/ukernels/matmul/pack/kai_lhs_pack_8x4_f32_bf16_neon.c create mode 100644 kai/ukernels/matmul/pack/kai_lhs_pack_8x4_f32_bf16_neon.h create mode 100644 kai/ukernels/matmul/pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.c create mode 100644 kai/ukernels/matmul/pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 642fce8d..02f535e6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -88,6 +88,12 @@ set(KLEIDIAI_FILES_NEON_FP16 kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c ) +set(KLEIDIAI_FILES_NEON_BF16 + kai/ukernels/matmul/pack/kai_lhs_pack_8x4_f32_bf16_neon.c + kai/ukernels/matmul/pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.c + 
kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.c +) + set(KLEIDIAI_FILES_NEON kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.c @@ -137,6 +143,7 @@ target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SCALAR}) if((CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") AND NOT MSVC) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_FP16}) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_BF16}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_DOTPROD}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_I8MM}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME}) @@ -145,6 +152,7 @@ if((CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS -march=armv8-a${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON} PROPERTIES COMPILE_OPTIONS -march=armv8-a${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_FP16} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+fp16${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_NEON_BF16} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+bf16${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+dotprod${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+i8mm${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_SME} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+sve+sve2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) diff --git a/kai/kai_common.h b/kai/kai_common.h index 
27034185..8fe70424 100644 --- a/kai/kai_common.h +++ b/kai/kai_common.h @@ -23,7 +23,7 @@ extern "C" { #define KAI_ERROR(msg) \ do { \ fflush(stdout); \ - fprintf(stderr, "%s:%d %s", __FILE__, __LINE__, msg); \ + fprintf(stderr, "%s:%d %s\n", __FILE__, __LINE__, msg); \ exit(EXIT_FAILURE); \ } while (0) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 7bcbc1a8..f9dc2c13 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -9,6 +9,7 @@ load( "kai_c_library", + "kai_cpu_bf16", "kai_cpu_dotprod", "kai_cpu_fp16", "kai_cpu_i8mm", "kai_cpu_neon", "kai_cpu_sme", @@ -32,6 +33,22 @@ kai_c_library( ], ) +kai_c_library( + name = "clamp_bf16_bf16_f32p_interface", + hdrs = ["matmul_clamp_bf16_bf16_f32p/matmul_clamp_bf16_bf16_f32p_interface.h"], + cpu_uarch = kai_cpu_bf16(), +) + +kai_c_library( + name = "clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla", + srcs = ["matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.c"], + hdrs = ["matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.h"], + cpu_uarch = kai_cpu_bf16(), + deps = [ + ":clamp_bf16_bf16_f32p_interface", + ], +) + kai_c_library( name = "clamp_f32_f32_f32p", srcs = ["matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.c"], @@ -159,6 +176,13 @@ kai_c_library( cpu_uarch = kai_cpu_sme(), ) +kai_c_library( + name = "lhs_pack_8x4_f32_bf16_neon", + srcs = ["pack/kai_lhs_pack_8x4_f32_bf16_neon.c"], + hdrs = ["pack/kai_lhs_pack_8x4_f32_bf16_neon.h"], + cpu_uarch = kai_cpu_bf16(), +) + kai_c_library( name = "rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon", srcs = ["pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c"], cpu_uarch = kai_cpu_fp16(), ) + +kai_c_library( + name = "matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12", + srcs = ["pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.c"], + 
hdrs = ["pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.h"], + cpu_uarch = kai_cpu_bf16(), +) + kai_c_library( name = "rhs_pack_kxn_f32pbiasf32_f32_f32_neon", srcs = ["pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c"], diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.c b/kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.c new file mode 100644 index 00000000..8e845095 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.c @@ -0,0 +1,596 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_BF16. +#else // Architectural features check. + +#include +#include +#include +#include +#include + +typedef bfloat16_t bfloat16; + +#include "kai/kai_common.h" +#include "kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.h" + +static const size_t kai_mr = 8; +static const size_t kai_nr = 12; +static const size_t kai_kr = 4; +static const size_t kai_sr = 1; + +size_t kai_get_m_step_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void) { + return kai_mr; +} + +size_t kai_get_n_step_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void) { + return kai_nr; +} + +size_t kai_get_mr_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void) { + return kai_sr; +} + +size_t 
kai_get_lhs_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(size_t m_idx, size_t stride) { + KAI_ASSUME(m_idx % kai_mr == 0); + + return m_idx * stride; +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_nr == 0); + + return n_idx / kai_nr * (kai_nr * sizeof(float) + kai_nr * k * sizeof(bfloat16)); +} + +size_t kai_get_dst_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(size_t m_idx, size_t n_idx, size_t stride) { + KAI_ASSUME(m_idx % kai_mr == 0); + KAI_ASSUME(n_idx % kai_nr == 0); + + return m_idx * stride + n_idx * sizeof(float); +} + +size_t kai_get_dst_size_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla( + size_t m, size_t n, size_t k, // + const void* lhs_packed, size_t lhs_stride, // + const void* rhs_packed, // + void* dst, size_t dst_stride_row, size_t dst_stride_col, // + float clamp_min, float clamp_max) { + KAI_ASSERT(dst_stride_col == sizeof(float)); + + KAI_UNUSED(lhs_stride); + + const void *Apanel = lhs_packed; + // const void *Bpanel = rhs_packed; + void *Cpanel = dst; + size_t ldc = dst_stride_row / sizeof(float); + + size_t M = m; + + typedef struct { + float maxval; + float minval; + unsigned int num_strings; + const unsigned int* string_lengths; + size_t N; + size_t K; + const void* Bpanel; + size_t output_offset; + size_t input_initial_col; + size_t input_offset; + void* output_ptr; + const void* bias; + } KernelArgs; + + KernelArgs ka; + + unsigned int string_length = k; + ka.num_strings = 1; + ka.string_lengths = &string_length; + ka.N = n; + ka.K = kai_roundup(k, 4) / 4 - 1; + + ka.Bpanel = rhs_packed; + ka.bias = NULL; + + // Direct input. + // const void* input_ptr = lhs; + // ka.input_offset = lhs_stride / sizeof(bfloat16); + ka.input_initial_col = 0; + + // Direct output. 
+ ka.output_ptr = dst; + // ka.output_offset = dst_stride_row / sizeof(float); + + // Clamping output. + ka.maxval = clamp_max; + ka.minval = clamp_min; + + __asm__ __volatile__( + "1:" // Height loop + "add x11, %x[Cpanel], %x[ldc], LSL #2\n" + "add x10, %x[Cpanel], %x[ldc], LSL #1\n" + "add x9, x11, %x[ldc], LSL #1\n" + "cmp %x[M], #0x8\n" + "add x28, %x[Cpanel], %x[ldc], LSL #3\n" + "add x27, %x[Cpanel], %x[ldc]\n" + "add x26, x10, %x[ldc]\n" + "add x25, x11, %x[ldc]\n" + "add x24, x9, %x[ldc]\n" + "bge 2f\n" + "cmp %x[M], #0x2\n" + "mov x24, %x[Cpanel]\n" + "csel x27, x27, %x[Cpanel], GE\n" + "csel x10, x10, %x[Cpanel], GT\n" + "cmp %x[M], #0x4\n" + "csel x26, x26, %x[Cpanel], GE\n" + "csel x11, x11, %x[Cpanel], GT\n" + "cmp %x[M], #0x6\n" + "csel x25, x25, %x[Cpanel], GE\n" + "csel x9, x9, %x[Cpanel], GT\n" + "2:" // all rows valid + "ldr x23, [%x[args_ptr], %[offsetof_N]]\n" + "ldr x22, [%x[args_ptr], %[offsetof_Bpanel]]\n" + "mov x21, %x[Apanel]\n" + "3:" // Width loop + "ldr q4, [x22, #0x0]\n" + "ldr q5, [x22, #0x10]\n" + "mov %x[Apanel], x21\n" + "ldr q6, [x22, #0x20]\n" + "ldr x20, [%x[args_ptr], %[offsetof_K]]\n" + "add x22, x22, #0x30\n" + "ldr q7, [x22, #0x0]\n" + "ldr q0, [%x[Apanel], #0x0]\n" + "ldr q1, [%x[Apanel], #0x10]\n" + "zip1 v8.2d, v4.2d, v4.2d\n" + "ldr q2, [%x[Apanel], #0x20]\n" + "zip2 v11.2d, v4.2d, v4.2d\n" + "ldr q4, [x22, #0x10]\n" + "zip1 v9.2d, v5.2d, v5.2d\n" + "zip2 v12.2d, v5.2d, v5.2d\n" + "cmp x20, #0x2\n" + "zip1 v10.2d, v6.2d, v6.2d\n" + "zip2 v13.2d, v6.2d, v6.2d\n" + "prfm pldl1keep, [%x[Apanel], #0x0]\n" + "mov v14.16b, v8.16b\n" + "mov v17.16b, v11.16b\n" + "prfm pldl1keep, [x22, #0x0]\n" + "mov v15.16b, v9.16b\n" + "mov v18.16b, v12.16b\n" + "prfm pldl1keep, [x22, #0x40]\n" + "mov v16.16b, v10.16b\n" + "mov v19.16b, v13.16b\n" + "prfm pldl1keep, [%x[Apanel], #0x40]\n" + "mov v20.16b, v8.16b\n" + "mov v21.16b, v9.16b\n" + "prfm pldl1keep, [x22, #0x80]\n" + "mov v22.16b, v10.16b\n" + "mov v23.16b, v11.16b\n" + "prfm 
pldl1keep, [%x[Apanel], #0x80]\n" + "mov v24.16b, v12.16b\n" + "mov v25.16b, v13.16b\n" + "prfm pldl1keep, [x22, #0xc0]\n" + "mov v26.16b, v8.16b\n" + "mov v27.16b, v9.16b\n" + "prfm pldl1keep, [x22, #0x100]\n" + "mov v28.16b, v10.16b\n" + "mov v29.16b, v11.16b\n" + "prfm pldl1keep, [%x[Apanel], #0xc0]\n" + "mov v30.16b, v12.16b\n" + "mov v31.16b, v13.16b\n" + "prfm pldl1keep, [x22, #0x140]\n" + "add x22, x22, #0x20\n" + "add %x[Apanel], %x[Apanel], #0x30\n" + "blt 5f\n" + "4:" // main loop head + "ldr q3, [%x[Apanel], #0x0]\n" + "ldr q5, [x22, #0x0]\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "ldr q6, [x22, #0x10]\n" + ".inst 0x6e44ec0b // bfmmla v11.4s, v0.8h, v4.8h\n" + ".inst 0x6e47ec2e // bfmmla v14.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec31 // bfmmla v17.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + "sub x20, x20, #0x2\n" + ".inst 0x6e44ec57 // bfmmla v23.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7a // bfmmla v26.4s, v3.8h, v7.8h\n" + "ldr q7, [x22, #0x20]\n" + ".inst 0x6e44ec7d // bfmmla v29.4s, v3.8h, v4.8h\n" + "ldr q4, [x22, #0x30]\n" + ".inst 0x6e45ec09 // bfmmla v9.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec2f // bfmmla v15.4s, v1.8h, v5.8h\n" + "cmp x20, #0x2\n" + ".inst 0x6e46ec32 // bfmmla v18.4s, v1.8h, v6.8h\n" + ".inst 0x6e45ec55 // bfmmla v21.4s, v2.8h, v5.8h\n" + "prfm pldl1keep, [%x[Apanel], #0x100]\n" + ".inst 0x6e46ec58 // bfmmla v24.4s, v2.8h, v6.8h\n" + ".inst 0x6e45ec7b // bfmmla v27.4s, v3.8h, v5.8h\n" + "ldr q5, [x22, #0x40]\n" + ".inst 0x6e46ec7e // bfmmla v30.4s, v3.8h, v6.8h\n" + "ldr q6, [x22, #0x50]\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + ".inst 0x6e44ec0d // bfmmla v13.4s, v0.8h, v4.8h\n" + "ldr q0, [%x[Apanel], #0x10]\n" + ".inst 0x6e47ec30 // bfmmla v16.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec33 // bfmmla v19.4s, v1.8h, v4.8h\n" + "ldr q1, [%x[Apanel], #0x20]\n" + ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n" + ".inst 
0x6e44ec59 // bfmmla v25.4s, v2.8h, v4.8h\n" + "ldr q2, [%x[Apanel], #0x30]\n" + ".inst 0x6e47ec7c // bfmmla v28.4s, v3.8h, v7.8h\n" + "ldr q7, [x22, #0x60]\n" + ".inst 0x6e44ec7f // bfmmla v31.4s, v3.8h, v4.8h\n" + "ldr q3, [%x[Apanel], #0x40]\n" + "ldr q4, [x22, #0x70]\n" + ".inst 0x6e45ec08 // bfmmla v8.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec2e // bfmmla v14.4s, v1.8h, v5.8h\n" + ".inst 0x6e46ec31 // bfmmla v17.4s, v1.8h, v6.8h\n" + "prfm pldl1keep, [x22, #0x180]\n" + ".inst 0x6e45ec54 // bfmmla v20.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "prfm pldl1keep, [x22, #0x1c0]\n" + ".inst 0x6e45ec7a // bfmmla v26.4s, v3.8h, v5.8h\n" + "ldr q5, [x22, #0x80]\n" + ".inst 0x6e46ec7d // bfmmla v29.4s, v3.8h, v6.8h\n" + "ldr q6, [x22, #0x90]\n" + "prfm pldl1keep, [%x[Apanel], #0x140]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + "prfm pldl1keep, [x22, #0x200]\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e47ec2f // bfmmla v15.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec32 // bfmmla v18.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7b // bfmmla v27.4s, v3.8h, v7.8h\n" + "ldr q7, [x22, #0xa0]\n" + ".inst 0x6e44ec7e // bfmmla v30.4s, v3.8h, v4.8h\n" + "ldr q4, [x22, #0xb0]\n" + ".inst 0x6e45ec0a // bfmmla v10.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + "ldr q0, [%x[Apanel], #0x50]\n" + ".inst 0x6e45ec30 // bfmmla v16.4s, v1.8h, v5.8h\n" + ".inst 0x6e46ec33 // bfmmla v19.4s, v1.8h, v6.8h\n" + "ldr q1, [%x[Apanel], #0x60]\n" + ".inst 0x6e45ec56 // bfmmla v22.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" + "ldr q2, [%x[Apanel], #0x70]\n" + ".inst 0x6e45ec7c // bfmmla v28.4s, v3.8h, v5.8h\n" + ".inst 0x6e46ec7f // bfmmla v31.4s, v3.8h, v6.8h\n" + "add %x[Apanel], %x[Apanel], #0x80\n" + "add x22, x22, #0xc0\n" + 
"bge 4b\n" + "5:" // main loop skip + "ldr q3, [%x[Apanel], #0x0]\n" + "ldr q5, [x22, #0x0]\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "ldr q6, [x22, #0x10]\n" + ".inst 0x6e44ec0b // bfmmla v11.4s, v0.8h, v4.8h\n" + ".inst 0x6e47ec2e // bfmmla v14.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec31 // bfmmla v17.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + "add %x[Apanel], %x[Apanel], #0x10\n" + ".inst 0x6e44ec57 // bfmmla v23.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7a // bfmmla v26.4s, v3.8h, v7.8h\n" + "ldr q7, [x22, #0x20]\n" + ".inst 0x6e44ec7d // bfmmla v29.4s, v3.8h, v4.8h\n" + "ldr q4, [x22, #0x30]\n" + ".inst 0x6e45ec09 // bfmmla v9.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec2f // bfmmla v15.4s, v1.8h, v5.8h\n" + "add x22, x22, #0x40\n" + ".inst 0x6e46ec32 // bfmmla v18.4s, v1.8h, v6.8h\n" + ".inst 0x6e45ec55 // bfmmla v21.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec58 // bfmmla v24.4s, v2.8h, v6.8h\n" + ".inst 0x6e45ec7b // bfmmla v27.4s, v3.8h, v5.8h\n" + ".inst 0x6e46ec7e // bfmmla v30.4s, v3.8h, v6.8h\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + ".inst 0x6e44ec0d // bfmmla v13.4s, v0.8h, v4.8h\n" + ".inst 0x6e47ec30 // bfmmla v16.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec33 // bfmmla v19.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n" + ".inst 0x6e44ec59 // bfmmla v25.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7c // bfmmla v28.4s, v3.8h, v7.8h\n" + ".inst 0x6e44ec7f // bfmmla v31.4s, v3.8h, v4.8h\n" + "cbz x20, 6f\n" + "ldr q5, [x22, #0x0]\n" + "ldr q0, [%x[Apanel], #0x0]\n" + "ldr q1, [%x[Apanel], #0x10]\n" + "ldr q6, [x22, #0x10]\n" + "ldr q2, [%x[Apanel], #0x20]\n" + "ldr q3, [%x[Apanel], #0x30]\n" + "add %x[Apanel], %x[Apanel], #0x40\n" + "ldr q7, [x22, #0x20]\n" + "ldr q4, [x22, #0x30]\n" + ".inst 0x6e45ec08 // bfmmla v8.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec2e // bfmmla v14.4s, v1.8h, v5.8h\n" + 
".inst 0x6e46ec31 // bfmmla v17.4s, v1.8h, v6.8h\n" + ".inst 0x6e45ec54 // bfmmla v20.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + ".inst 0x6e45ec7a // bfmmla v26.4s, v3.8h, v5.8h\n" + "ldr q5, [x22, #0x40]\n" + ".inst 0x6e46ec7d // bfmmla v29.4s, v3.8h, v6.8h\n" + "ldr q6, [x22, #0x50]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e47ec2f // bfmmla v15.4s, v1.8h, v7.8h\n" + "add x22, x22, #0x60\n" + ".inst 0x6e44ec32 // bfmmla v18.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7b // bfmmla v27.4s, v3.8h, v7.8h\n" + ".inst 0x6e44ec7e // bfmmla v30.4s, v3.8h, v4.8h\n" + ".inst 0x6e45ec0a // bfmmla v10.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec30 // bfmmla v16.4s, v1.8h, v5.8h\n" + ".inst 0x6e46ec33 // bfmmla v19.4s, v1.8h, v6.8h\n" + ".inst 0x6e45ec56 // bfmmla v22.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" + ".inst 0x6e45ec7c // bfmmla v28.4s, v3.8h, v5.8h\n" + ".inst 0x6e46ec7f // bfmmla v31.4s, v3.8h, v6.8h\n" + "6:" // multiply loop done + "add x20, %x[args_ptr], %[offset_max]\n" + "uzp1 v7.2d, v8.2d, v11.2d\n" + "uzp2 v8.2d, v8.2d, v11.2d\n" + "ld1r { v1.4s }, [x20]\n" + "uzp1 v11.2d, v9.2d, v12.2d\n" + "uzp2 v9.2d, v9.2d, v12.2d\n" + "uzp1 v12.2d, v10.2d, v13.2d\n" + "uzp2 v10.2d, v10.2d, v13.2d\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x20]\n" + "uzp1 v13.2d, v14.2d, v17.2d\n" + "uzp2 v14.2d, v14.2d, v17.2d\n" + "uzp1 v17.2d, v15.2d, v18.2d\n" + "uzp2 v15.2d, v15.2d, v18.2d\n" + "cmp x23, #0xc\n" + "uzp1 v18.2d, v16.2d, v19.2d\n" + "uzp2 v16.2d, v16.2d, v19.2d\n" + "uzp1 v19.2d, v20.2d, v23.2d\n" + "uzp2 v20.2d, v20.2d, v23.2d\n" + "uzp1 v23.2d, v21.2d, v24.2d\n" + "uzp2 v21.2d, v21.2d, v24.2d\n" + "uzp1 v24.2d, v22.2d, v25.2d\n" + "uzp2 v22.2d, v22.2d, 
v25.2d\n" + "uzp1 v25.2d, v26.2d, v29.2d\n" + "uzp2 v26.2d, v26.2d, v29.2d\n" + "uzp1 v29.2d, v27.2d, v30.2d\n" + "uzp2 v27.2d, v27.2d, v30.2d\n" + "uzp1 v30.2d, v28.2d, v31.2d\n" + "uzp2 v28.2d, v28.2d, v31.2d\n" + "fmin v7.4s, v7.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v19.4s, v19.4s, v1.4s\n" + "fmin v23.4s, v23.4s, v1.4s\n" + "fmin v24.4s, v24.4s, v1.4s\n" + "fmin v20.4s, v20.4s, v1.4s\n" + "fmin v21.4s, v21.4s, v1.4s\n" + "fmin v22.4s, v22.4s, v1.4s\n" + "fmin v25.4s, v25.4s, v1.4s\n" + "fmin v29.4s, v29.4s, v1.4s\n" + "fmin v30.4s, v30.4s, v1.4s\n" + "fmin v26.4s, v26.4s, v1.4s\n" + "fmin v27.4s, v27.4s, v1.4s\n" + "fmin v28.4s, v28.4s, v1.4s\n" + "fmax v7.4s, v7.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v23.4s, v23.4s, v0.4s\n" + "fmax v24.4s, v24.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "fmax v25.4s, v25.4s, v0.4s\n" + "fmax v29.4s, v29.4s, v0.4s\n" + "fmax v30.4s, v30.4s, v0.4s\n" + "fmax v26.4s, v26.4s, v0.4s\n" + "fmax v27.4s, v27.4s, v0.4s\n" + "fmax v28.4s, v28.4s, v0.4s\n" + "blt 7f\n" + "str q26, [x24, #0x0]\n" + "str q27, [x24, #0x10]\n" + "str q28, [x24, #0x20]\n" + "add x24, x24, #0x30\n" + "str q25, [x9, #0x0]\n" + "str q29, [x9, #0x10]\n" + "str q30, [x9, #0x20]\n" + "add x9, 
x9, #0x30\n" + "str q20, [x25, #0x0]\n" + "str q21, [x25, #0x10]\n" + "str q22, [x25, #0x20]\n" + "add x25, x25, #0x30\n" + "str q19, [x11, #0x0]\n" + "str q23, [x11, #0x10]\n" + "str q24, [x11, #0x20]\n" + "add x11, x11, #0x30\n" + "str q14, [x26, #0x0]\n" + "str q15, [x26, #0x10]\n" + "str q16, [x26, #0x20]\n" + "add x26, x26, #0x30\n" + "str q13, [x10, #0x0]\n" + "str q17, [x10, #0x10]\n" + "str q18, [x10, #0x20]\n" + "add x10, x10, #0x30\n" + "str q8, [x27, #0x0]\n" + "str q9, [x27, #0x10]\n" + "str q10, [x27, #0x20]\n" + "add x27, x27, #0x30\n" + "str q7, [%x[Cpanel], #0x0]\n" + "str q11, [%x[Cpanel], #0x10]\n" + "str q12, [%x[Cpanel], #0x20]\n" + "add %x[Cpanel], %x[Cpanel], #0x30\n" + "b 14f\n" + "7:" // partial output + "tbz x23, #3, 9f\n" + "st1 { v26.4s }, [x24], #0x10\n" + "st1 { v27.4s }, [x24], #0x10\n" + "st1 { v25.4s }, [x9], #0x10\n" + "st1 { v29.4s }, [x9], #0x10\n" + "st1 { v20.4s }, [x25], #0x10\n" + "st1 { v21.4s }, [x25], #0x10\n" + "st1 { v19.4s }, [x11], #0x10\n" + "st1 { v23.4s }, [x11], #0x10\n" + "st1 { v14.4s }, [x26], #0x10\n" + "st1 { v15.4s }, [x26], #0x10\n" + "st1 { v13.4s }, [x10], #0x10\n" + "st1 { v17.4s }, [x10], #0x10\n" + "st1 { v8.4s }, [x27], #0x10\n" + "st1 { v9.4s }, [x27], #0x10\n" + "st1 { v7.4s }, [%x[Cpanel]], #0x10\n" + "st1 { v11.4s }, [%x[Cpanel]], #0x10\n" + "tbz x23, #1, 8f\n" + "str d28, [x24], #0x8\n" + "str d30, [x9], #0x8\n" + "str d22, [x25], #0x8\n" + "str d24, [x11], #0x8\n" + "str d16, [x26], #0x8\n" + "str d18, [x10], #0x8\n" + "str d10, [x27], #0x8\n" + "str d12, [%x[Cpanel]], #0x8\n" + "tbz x23, #0, 13f\n" + "st1 { v28.s }[2], [x24]\n" + "st1 { v30.s }[2], [x9]\n" + "st1 { v22.s }[2], [x25]\n" + "st1 { v24.s }[2], [x11]\n" + "st1 { v16.s }[2], [x26]\n" + "st1 { v18.s }[2], [x10]\n" + "st1 { v10.s }[2], [x27]\n" + "st1 { v12.s }[2], [%x[Cpanel]]\n" + "b 13f\n" + "8:" // partial result store: partial_1_8 + "tbz x23, #0, 13f\n" + "str s28, [x24, #0x0]\n" + "str s30, [x9, #0x0]\n" + "str s22, [x25, #0x0]\n" 
+ "str s24, [x11, #0x0]\n" + "str s16, [x26, #0x0]\n" + "str s18, [x10, #0x0]\n" + "str s10, [x27, #0x0]\n" + "str s12, [%x[Cpanel], #0x0]\n" + "b 13f\n" + "9:" // partial result store: partial_4_0 + "tbz x23, #2, 11f\n" + "st1 { v26.4s }, [x24], #0x10\n" + "st1 { v25.4s }, [x9], #0x10\n" + "st1 { v20.4s }, [x25], #0x10\n" + "st1 { v19.4s }, [x11], #0x10\n" + "st1 { v14.4s }, [x26], #0x10\n" + "st1 { v13.4s }, [x10], #0x10\n" + "st1 { v8.4s }, [x27], #0x10\n" + "st1 { v7.4s }, [%x[Cpanel]], #0x10\n" + "tbz x23, #1, 10f\n" + "str d27, [x24], #0x8\n" + "str d29, [x9], #0x8\n" + "str d21, [x25], #0x8\n" + "str d23, [x11], #0x8\n" + "str d15, [x26], #0x8\n" + "str d17, [x10], #0x8\n" + "str d9, [x27], #0x8\n" + "str d11, [%x[Cpanel]], #0x8\n" + "tbz x23, #0, 13f\n" + "st1 { v27.s }[2], [x24]\n" + "st1 { v29.s }[2], [x9]\n" + "st1 { v21.s }[2], [x25]\n" + "st1 { v23.s }[2], [x11]\n" + "st1 { v15.s }[2], [x26]\n" + "st1 { v17.s }[2], [x10]\n" + "st1 { v9.s }[2], [x27]\n" + "st1 { v11.s }[2], [%x[Cpanel]]\n" + "b 13f\n" + "10:" // partial result store: partial_1_4 + "tbz x23, #0, 13f\n" + "str s27, [x24, #0x0]\n" + "str s29, [x9, #0x0]\n" + "str s21, [x25, #0x0]\n" + "str s23, [x11, #0x0]\n" + "str s15, [x26, #0x0]\n" + "str s17, [x10, #0x0]\n" + "str s9, [x27, #0x0]\n" + "str s11, [%x[Cpanel], #0x0]\n" + "b 13f\n" + "11:" // partial result store: partial_2_0 + "tbz x23, #1, 12f\n" + "str d26, [x24], #0x8\n" + "str d25, [x9], #0x8\n" + "str d20, [x25], #0x8\n" + "str d19, [x11], #0x8\n" + "str d14, [x26], #0x8\n" + "str d13, [x10], #0x8\n" + "str d8, [x27], #0x8\n" + "str d7, [%x[Cpanel]], #0x8\n" + "tbz x23, #0, 13f\n" + "st1 { v26.s }[2], [x24]\n" + "st1 { v25.s }[2], [x9]\n" + "st1 { v20.s }[2], [x25]\n" + "st1 { v19.s }[2], [x11]\n" + "st1 { v14.s }[2], [x26]\n" + "st1 { v13.s }[2], [x10]\n" + "st1 { v8.s }[2], [x27]\n" + "st1 { v7.s }[2], [%x[Cpanel]]\n" + "b 13f\n" + "12:" // partial result store: partial_1_0 + "str s26, [x24, #0x0]\n" + "str s25, [x9, #0x0]\n" + 
"str s20, [x25, #0x0]\n" + "str s19, [x11, #0x0]\n" + "str s14, [x26, #0x0]\n" + "str s13, [x10, #0x0]\n" + "str s8, [x27, #0x0]\n" + "str s7, [%x[Cpanel], #0x0]\n" + "13:" // partial result store: Done + "14:" // store done + "subs x23, x23, #0xc\n" + "bgt 3b\n" + "subs %x[M], %x[M], #0x8\n" + "mov %x[Cpanel], x28\n" + "bgt 1b\n" + : [Apanel] "+&r" (Apanel), [Cpanel] "+&r" (Cpanel), [M] "+&r" (M) + : [args_ptr] "r" (&ka), [ldc] "r" (ldc * sizeof(float)), [offset_max] "I" (offsetof(KernelArgs, maxval)), [offset_min] "I" (offsetof(KernelArgs, minval)), [offsetof_Bpanel] "I" (offsetof(KernelArgs, Bpanel)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_N] "I" (offsetof(KernelArgs, N)) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.h b/kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.h new file mode 100644 index 00000000..56425303 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.h @@ -0,0 +1,122 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_BF16. +#else // Architectural features check. + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// -------------------------------------------------- + +/// Gets m step value. 
+/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return The m step value. +size_t kai_get_m_step_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void); + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void); + +size_t kai_get_mr_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void); + +/// Gets nr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The nr value. +size_t kai_get_nr_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void); + +/// Gets kr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The kr value. +size_t kai_get_kr_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void); + +/// Gets sr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The sr value. +size_t kai_get_sr_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void); + +/// Gets the offset in bytes to the data element in the LHS matrix buffer. +/// +/// @param[in] m_idx Row index. +/// @param[in] stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(size_t m_idx, size_t stride); + +/// Gets the offset in bytes to the data element in the packed RHS matrix buffer. +/// +/// @param[in] n_idx Row index. +/// @param[in] k Number of columns. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(size_t n_idx, size_t k); + +/// Gets the offset in bytes to the data element in the destination matrix buffer. +/// +/// @param[in] m_idx Row index. +/// @param[in] n_idx Column index. 
+/// @param[in] stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_dst_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(size_t m_idx, size_t n_idx, size_t stride); + +/// Gets the size in bytes of the destination matrix buffer. +/// +/// @param[in] m Number of rows. +/// @param[in] n Number of columns. +/// +/// @return The size in bytes of the destination matrix buffer. +size_t kai_get_dst_size_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(size_t m, size_t n); + +/// Runs the matrix multiplication microkernel followed by a clamp operation. +/// +/// The pointer of each buffers (LHS, packed RHS and output) needs to be added with offset +/// calculated using the following functions: +/// +/// * LHS: @ref kai_get_lhs_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla. +/// +/// @param[in] m Number of output rows to be computed. +/// @param[in] n Number of output columns to be computed. +/// @param[in] k Common dimension of the LHS and RHS operand. +/// @param[in] lhs LHS matrix buffer. +/// @param[in] lhs_stride Row stride in bytes of the LHS matrix. +/// @param[in] rhs_packed Packed RHS buffer. +/// @param[out] dst Output matrix buffer. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) +/// @param[in] clamp_min Minimum value to clamp the final result. +/// @param[in] clamp_max Maximum value to clamp the final result. 
+void kai_run_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla( + size_t m, size_t n, size_t k, // + const void* lhs_packed, size_t lhs_stride, // + const void* rhs_packed, // + void* dst, size_t dst_stride_row, size_t dst_stride_col, // + float clamp_min, float clamp_max); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/matmul_clamp_bf16_bf16_f32p_interface.h b/kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/matmul_clamp_bf16_bf16_f32p_interface.h new file mode 100644 index 00000000..28a715b4 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/matmul_clamp_bf16_bf16_f32p_interface.h @@ -0,0 +1,57 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_BF16. +#else // Architectural features check. 
+ +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// All micro-kernels variants of the same type share the same interfaces +// In this case, the micro-kernel type is: matmul_clamp_bf16_bf16_f32p + +/// Micro-kernel helper functions ("get" methods) +typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_m_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_n_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_mr_func_t)(void); +typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_nr_func_t)(void); +typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_kr_func_t)(void); +typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_sr_func_t)(void); +typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_lhs_offset_func_t)(size_t m_idx, size_t lhs_stride); +typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); +typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride); +typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_dst_size_func_t)(size_t m, size_t n); + +/// Micro-kernel core function ("run" method) +typedef void (*kai_matmul_clamp_bf16_bf16_f32p_run_matmul_func_t)( + size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* rhs_packed, void* dst, + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max); + +/// Micro-kernel interface +struct kai_matmul_clamp_bf16_bf16_f32p_ukernel { + kai_matmul_clamp_bf16_bf16_f32p_get_m_step_func_t get_m_step; + kai_matmul_clamp_bf16_bf16_f32p_get_n_step_func_t get_n_step; + kai_matmul_clamp_bf16_bf16_f32p_get_mr_func_t get_mr; + kai_matmul_clamp_bf16_bf16_f32p_get_nr_func_t get_nr; + kai_matmul_clamp_bf16_bf16_f32p_get_nr_func_t get_kr; + kai_matmul_clamp_bf16_bf16_f32p_get_sr_func_t get_sr; + kai_matmul_clamp_bf16_bf16_f32p_get_lhs_offset_func_t get_lhs_packed_offset; + 
kai_matmul_clamp_bf16_bf16_f32p_get_rhs_packed_offset_func_t get_rhs_packed_offset; + kai_matmul_clamp_bf16_bf16_f32p_get_dst_offset_func_t get_dst_offset; + kai_matmul_clamp_bf16_bf16_f32p_get_dst_size_func_t get_dst_size; + kai_matmul_clamp_bf16_bf16_f32p_run_matmul_func_t run_matmul; +}; + +#ifdef __cplusplus +} +#endif + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_8x4_f32_bf16_neon.c b/kai/ukernels/matmul/pack/kai_lhs_pack_8x4_f32_bf16_neon.c new file mode 100644 index 00000000..af8557d0 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_8x4_f32_bf16_neon.c @@ -0,0 +1,222 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_BF16. +#else // Architectural features check. + +#include +#include +#include +#include "kai/kai_common.h" + +static const size_t kai_mr = 8; +static const size_t kai_kr = 4; +static const size_t kai_sr = 1; +static const size_t vec_len = 1; + +size_t kai_get_m_step_lhs_pack_8x4_f32_bf16_neon(size_t mr) { + KAI_ASSUME(mr == kai_mr * vec_len); + KAI_UNUSED(mr); + + return kai_mr * vec_len; +} + +size_t kai_get_lhs_offset_lhs_pack_8x4_f32_bf16_neon(size_t m_idx, size_t lhs_stride) { + KAI_ASSUME(m_idx % (kai_mr * vec_len) == 0); + + return m_idx * lhs_stride; +} + +size_t kai_get_lhs_packed_offset_lhs_pack_8x4_f32_bf16_neon(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { + const size_t scaled_mr = kai_mr * vec_len; + KAI_ASSUME(m_idx % scaled_mr == 0); + KAI_ASSUME(mr == scaled_mr); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + + KAI_UNUSED(mr); + KAI_UNUSED(kr); + KAI_UNUSED(sr); + + return m_idx * kai_roundup(k, kr) * sizeof(bfloat16_t); +} + +size_t kai_get_lhs_packed_size_lhs_pack_8x4_f32_bf16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { + 
KAI_ASSUME(mr == kai_mr * vec_len); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + + KAI_UNUSED(mr); + KAI_UNUSED(kr); + KAI_UNUSED(sr); + + return kai_roundup(m, kai_mr * vec_len) * kai_roundup(k, kai_kr) * sizeof(bfloat16_t); +} + +void kai_run_lhs_pack_8x4_f32_bf16_neon( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, + const void* lhs, size_t lhs_stride, void* lhs_packed +) +{ + KAI_ASSUME(mr == kai_mr * vec_len); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + KAI_ASSUME(lhs != NULL); + KAI_ASSUME(lhs_packed != NULL); + + KAI_ASSUME(m_idx_start == 0); + + const size_t block_height = kai_mr * vec_len; + const size_t row_offset = 0; + + const void* in[block_height]; + + for (size_t block_y = 0; block_y < m; block_y += block_height) { + const size_t height = KAI_MIN(m - block_y, block_height); + void* out = (char*)lhs_packed + block_y * kai_roundup(k,kr) * sizeof(bfloat16_t); + size_t width = k; + + for (size_t y = 0; y < height; y++) { + in[y] = (char*)lhs + (block_y + y) * lhs_stride; + } + + __asm__ __volatile__( + "ldr x28, [%x[in], #0x0]\n" + "ldr x27, [%x[in], #0x8]\n" + "cmp %x[height], #0x8\n" + "ldr x26, [%x[in], #0x10]\n" + "ldr x25, [%x[in], #0x18]\n" + "ldr x24, [%x[in], #0x20]\n" + "ldr x23, [%x[in], #0x28]\n" + "ldr x22, [%x[in], #0x30]\n" + "ldr x21, [%x[in], #0x38]\n" + "add x28, x28, %x[row_offset], LSL #2\n" + "add x27, x27, %x[row_offset], LSL #2\n" + "add x26, x26, %x[row_offset], LSL #2\n" + "add x25, x25, %x[row_offset], LSL #2\n" + "add x24, x24, %x[row_offset], LSL #2\n" + "add x23, x23, %x[row_offset], LSL #2\n" + "add x22, x22, %x[row_offset], LSL #2\n" + "add x21, x21, %x[row_offset], LSL #2\n" + "beq 1f\n" + "cmp %x[height], #0x2\n" + "mov x21, x28\n" + "csel x27, x27, x28, GE\n" + "csel x26, x26, x28, GT\n" + "cmp %x[height], #0x4\n" + "csel x25, x25, x28, GE\n" + "csel x24, x24, x28, GT\n" + "cmp %x[height], #0x6\n" + "csel x23, x23, x28, GE\n" + "csel x22, x22, x28, GT\n" + "1:" // 
no_pointer_adj + "cmp %x[width], #0x4\n" + "prfm pldl1keep, [x28, #0x0]\n" + "prfm pldl1keep, [x27, #0x0]\n" + "prfm pldl1keep, [x26, #0x0]\n" + "prfm pldl1keep, [x25, #0x0]\n" + "prfm pldl1keep, [x24, #0x0]\n" + "prfm pldl1keep, [x23, #0x0]\n" + "prfm pldl1keep, [x22, #0x0]\n" + "prfm pldl1keep, [x21, #0x0]\n" + "prfm pldl1keep, [x28, #0x40]\n" + "prfm pldl1keep, [x27, #0x40]\n" + "prfm pldl1keep, [x26, #0x40]\n" + "prfm pldl1keep, [x25, #0x40]\n" + "prfm pldl1keep, [x24, #0x40]\n" + "prfm pldl1keep, [x23, #0x40]\n" + "prfm pldl1keep, [x22, #0x40]\n" + "prfm pldl1keep, [x21, #0x40]\n" + "blt 3f\n" + "2:" // Main loop head + "ldr q19, [x28], #0x10\n" + "ldr q18, [x26], #0x10\n" + "subs %x[width], %x[width], #0x4\n" + "ldr q17, [x24], #0x10\n" + "ldr q16, [x22], #0x10\n" + "cmp %x[width], #0x4\n" + "ldr q23, [x27], #0x10\n" + "ldr q22, [x25], #0x10\n" + "ldr q21, [x23], #0x10\n" + "ldr q20, [x21], #0x10\n" + ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n" + ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n" + ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "prfm pldl1keep, [x28, #0x70]\n" + "prfm pldl1keep, [x27, #0x70]\n" + "prfm pldl1keep, [x26, #0x70]\n" + "prfm pldl1keep, [x25, #0x70]\n" + "prfm pldl1keep, [x24, #0x70]\n" + "prfm pldl1keep, [x23, #0x70]\n" + ".inst 0x4ea16af3 // bfcvtn2 v19.8h, v23.4s\n" + ".inst 0x4ea16ad2 // bfcvtn2 v18.8h, v22.4s\n" + "prfm pldl1keep, [x22, #0x70]\n" + "prfm pldl1keep, [x21, #0x70]\n" + ".inst 0x4ea16ab1 // bfcvtn2 v17.8h, v21.4s\n" + ".inst 0x4ea16a90 // bfcvtn2 v16.8h, v20.4s\n" + "str q19, [%x[out_ptr], #0x0]\n" + "str q18, [%x[out_ptr], #0x10]\n" + "str q17, [%x[out_ptr], #0x20]\n" + "str q16, [%x[out_ptr], #0x30]\n" + "add %x[out_ptr], %x[out_ptr], #0x40\n" + "bge 2b\n" + "3:" // Main loop skip + "cbz %x[width], 6f\n" + "tbz %x[width], #1, 4f\n" + "ldr d19, [x28], #0x8\n" + "ldr d23, [x27], #0x8\n" + "mov x20, #0x1\n" + "ldr d18, [x26], #0x8\n" + "ldr d22, [x25], #0x8\n" + "ldr 
d17, [x24], #0x8\n" + "ldr d21, [x23], #0x8\n" + "ldr d16, [x22], #0x8\n" + "ldr d20, [x21], #0x8\n" + "tbz %x[width], #0, 5f\n" + "ld1 { v19.s }[2], [x28]\n" + "ld1 { v23.s }[2], [x27]\n" + "ld1 { v18.s }[2], [x26]\n" + "ld1 { v22.s }[2], [x25]\n" + "ld1 { v17.s }[2], [x24]\n" + "ld1 { v21.s }[2], [x23]\n" + "ld1 { v16.s }[2], [x22]\n" + "ld1 { v20.s }[2], [x21]\n" + "b 5f\n" + "4:" // odd_loads_1_0 + "ldr s19, [x28, #0x0]\n" + "ldr s23, [x27, #0x0]\n" + "mov x20, #0x1\n" + "ldr s18, [x26, #0x0]\n" + "ldr s22, [x25, #0x0]\n" + "ldr s17, [x24, #0x0]\n" + "ldr s21, [x23, #0x0]\n" + "ldr s16, [x22, #0x0]\n" + "ldr s20, [x21, #0x0]\n" + "5:" // Odd load end + ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n" + ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n" + ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + ".inst 0x4ea16af3 // bfcvtn2 v19.8h, v23.4s\n" + ".inst 0x4ea16ad2 // bfcvtn2 v18.8h, v22.4s\n" + ".inst 0x4ea16ab1 // bfcvtn2 v17.8h, v21.4s\n" + ".inst 0x4ea16a90 // bfcvtn2 v16.8h, v20.4s\n" + "str q19, [%x[out_ptr], #0x0]\n" + "str q18, [%x[out_ptr], #0x10]\n" + "str q17, [%x[out_ptr], #0x20]\n" + "str q16, [%x[out_ptr], #0x30]\n" + "add %x[out_ptr], %x[out_ptr], #0x40\n" + "6:" // Odds skip + : [out_ptr] "+&r" (out), [width] "+&r" (width) + : [height] "r" (height), [in] "r" (in), [row_offset] "r" (row_offset) + : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); + } +} + +#endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_8x4_f32_bf16_neon.h b/kai/ukernels/matmul/pack/kai_lhs_pack_8x4_f32_bf16_neon.h new file mode 100644 index 00000000..5cd514c9 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_8x4_f32_bf16_neon.h @@ -0,0 +1,31 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#include +#include +#include "kai/kai_common.h" + +size_t kai_get_m_step_lhs_pack_8x4_f32_bf16_neon(size_t mr); + +size_t kai_get_lhs_offset_lhs_pack_8x4_f32_bf16_neon(size_t m_idx, size_t lhs_stride); + +size_t kai_get_lhs_packed_offset_lhs_pack_8x4_f32_bf16_neon(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr); + +size_t kai_get_lhs_packed_size_lhs_pack_8x4_f32_bf16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr); + +void kai_run_lhs_pack_8x4_f32_bf16_neon( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, + const void* lhs, size_t lhs_stride, void* lhs_packed +); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus \ No newline at end of file diff --git a/kai/ukernels/matmul/pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.c b/kai/ukernels/matmul/pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.c new file mode 100644 index 00000000..36f60c26 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.c @@ -0,0 +1,474 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) +#error This file must be compiled for AArch64. +#else // Architectural features check. 
+ +#include + +#include +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_nr = 12; +static const size_t kai_kr = 4; + +size_t kai_get_n_step_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(void) { + return kai_nr; +} + +size_t kai_get_rhs_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(size_t n_idx) { + KAI_ASSUME(n_idx % kai_nr == 0); + + return n_idx * sizeof(float); +} + + +size_t kai_get_bias_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(size_t n_idx) { + return n_idx * sizeof(uint32_t); +} + + +size_t kai_get_rhs_packed_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_nr == 0); + + return n_idx * (sizeof(uint32_t) + kai_roundup(k, kai_kr) * sizeof(bfloat16_t)); +} + +size_t kai_get_rhs_packed_size_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(size_t n, size_t k) { + return kai_get_rhs_packed_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(kai_roundup(n, kai_nr), k); +} + +void kai_run_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12( + size_t num_groups, + size_t n, + size_t k, + size_t nr, + size_t kr, + size_t sr, + size_t rhs_stride, + const void* rhs, + const void* bias, + const void* scale, + void* rhs_packed, + size_t extra_bytes, + const void* params) { + KAI_ASSUME(num_groups == 1); + KAI_ASSUME(nr == kai_nr); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == 1); + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(bias != NULL); + KAI_ASSUME(scale == NULL); + KAI_ASSUME(rhs_packed != NULL); + KAI_ASSUME(extra_bytes == 0); + KAI_ASSUME(params == NULL); + + size_t height = k; + const size_t width = n; + const void* in = rhs; + void* out = rhs_packed; + const size_t in_stride = rhs_stride; + float *pad_row = (float*)alloca(width * sizeof(float)); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(float)); + } + + size_t 
out_stride = kai_nr * kai_roundup(height, 4) * sizeof(bfloat16_t) + kai_nr * sizeof(uint32_t); + + __asm__ __volatile__( + "mov x22, %x[width]\n" + "mov x21, %x[out]\n" + "cmp x22, #0xc\n" + "blt 2f\n" + "1:" // Bias: Full loop + "ldr q16, [%x[bias], #0x0]\n" + "ldr q26, [%x[bias], #0x10]\n" + "sub x22, x22, #0xc\n" + "ldr q8, [%x[bias], #0x20]\n" + "cmp x22, #0xc\n" + "add %x[bias], %x[bias], #0x30\n" + "str q16, [x21, #0x0]\n" + "str q26, [x21, #0x10]\n" + "str q8, [x21, #0x20]\n" + "add x21, x21, %x[out_stride]\n" + "bge 1b\n" + "cbz x22, 3f\n" + "2:" // Bias: Tail loop + "ldr w20, [%x[bias], #0x0]\n" + "sub x22, x22, #0x1\n" + "add %x[bias], %x[bias], #0x4\n" + "cmp x22, #0x0\n" + "str w20, [x21]\n" + "add x21, x21, #0x4\n" + "bgt 2b\n" + "3:" // Bias: Done + "cmp %x[height], #0x8\n" + "add %x[out], %x[out], #0x30\n" + "blt 12f\n" + "4:" // Main row loop: Head + "mov x9, %x[in]\n" + "mov x28, %x[width]\n" + "mov x27, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "add x26, x9, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "cmp x28, #0xc\n" + "add x23, x24, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x21, x22, %x[in_stride]\n" + "add x20, x21, %x[in_stride]\n" + "add %x[in], x20, %x[in_stride]\n" + "blt 6f\n" + "5:" // Main row loop: Column loop + "ldr q28, [x9], #0x10\n" + "ldr q27, [x26], #0x10\n" + "sub x28, x28, #0xc\n" + "ldr q11, [x25], #0x10\n" + "ldr q5, [x24], #0x10\n" + "cmp x28, #0xc\n" + "ldr q14, [x23], #0x10\n" + "ldr q6, [x22], #0x10\n" + "ldr q2, [x21], #0x10\n" + "ldr q18, [x20], #0x10\n" + "ldr q1, [x9], #0x10\n" + "ldr q7, [x26], #0x10\n" + "zip1 v15.4s, v28.4s, v11.4s\n" + "zip1 v8.4s, v27.4s, v5.4s\n" + "ldr q3, [x25], #0x10\n" + "ldr q23, [x24], #0x10\n" + "zip2 v17.4s, v28.4s, v11.4s\n" + "zip2 v27.4s, v27.4s, v5.4s\n" + "ldr q5, [x23], #0x10\n" + "ldr q30, [x22], #0x10\n" + "zip1 v26.4s, v14.4s, v2.4s\n" + "zip1 v31.4s, v6.4s, v18.4s\n" + "ldr q20, [x21], #0x10\n" + "ldr q16, 
[x20], #0x10\n" + "zip2 v12.4s, v14.4s, v2.4s\n" + "zip2 v24.4s, v6.4s, v18.4s\n" + "ldr q29, [x9], #0x10\n" + "ldr q6, [x26], #0x10\n" + "zip1 v18.4s, v1.4s, v3.4s\n" + "zip1 v4.4s, v7.4s, v23.4s\n" + "ldr q22, [x25], #0x10\n" + "ldr q0, [x24], #0x10\n" + "zip2 v3.4s, v1.4s, v3.4s\n" + "zip2 v1.4s, v7.4s, v23.4s\n" + "ldr q2, [x23], #0x10\n" + "ldr q10, [x22], #0x10\n" + "zip1 v28.4s, v5.4s, v20.4s\n" + "zip1 v14.4s, v30.4s, v16.4s\n" + "ldr q9, [x21], #0x10\n" + "ldr q23, [x20], #0x10\n" + "zip2 v13.4s, v5.4s, v20.4s\n" + "zip2 v30.4s, v30.4s, v16.4s\n" + "zip1 v16.4s, v29.4s, v22.4s\n" + "zip1 v5.4s, v6.4s, v0.4s\n" + "zip2 v22.4s, v29.4s, v22.4s\n" + "zip2 v0.4s, v6.4s, v0.4s\n" + "zip1 v7.4s, v2.4s, v9.4s\n" + "zip1 v19.4s, v10.4s, v23.4s\n" + "zip2 v21.4s, v2.4s, v9.4s\n" + "zip2 v25.4s, v10.4s, v23.4s\n" + "zip1 v11.4s, v15.4s, v8.4s\n" + "zip1 v9.4s, v17.4s, v27.4s\n" + "zip1 v6.4s, v18.4s, v4.4s\n" + "zip1 v2.4s, v3.4s, v1.4s\n" + "zip1 v29.4s, v16.4s, v5.4s\n" + "zip1 v20.4s, v22.4s, v0.4s\n" + "zip1 v10.4s, v26.4s, v31.4s\n" + "zip1 v23.4s, v12.4s, v24.4s\n" + ".inst 0x0ea1696b // bfcvtn v11.4h, v11.4s\n" + "zip2 v8.4s, v15.4s, v8.4s\n" + "zip1 v15.4s, v28.4s, v14.4s\n" + ".inst 0x0ea16929 // bfcvtn v9.4h, v9.4s\n" + "zip2 v27.4s, v17.4s, v27.4s\n" + "zip1 v17.4s, v13.4s, v30.4s\n" + ".inst 0x0ea168c6 // bfcvtn v6.4h, v6.4s\n" + "zip2 v4.4s, v18.4s, v4.4s\n" + "zip1 v18.4s, v7.4s, v19.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "zip2 v1.4s, v3.4s, v1.4s\n" + "zip1 v3.4s, v21.4s, v25.4s\n" + ".inst 0x0ea16bbd // bfcvtn v29.4h, v29.4s\n" + "zip2 v5.4s, v16.4s, v5.4s\n" + ".inst 0x0ea16a94 // bfcvtn v20.4h, v20.4s\n" + "zip2 v0.4s, v22.4s, v0.4s\n" + ".inst 0x0ea16956 // bfcvtn v22.4h, v10.4s\n" + "zip2 v31.4s, v26.4s, v31.4s\n" + ".inst 0x0ea16aea // bfcvtn v10.4h, v23.4s\n" + "zip2 v26.4s, v12.4s, v24.4s\n" + ".inst 0x0ea169ef // bfcvtn v15.4h, v15.4s\n" + "zip2 v12.4s, v28.4s, v14.4s\n" + ".inst 0x0ea16a2e // bfcvtn v14.4h, v17.4s\n" + "zip2 
v24.4s, v13.4s, v30.4s\n" + ".inst 0x0ea16a57 // bfcvtn v23.4h, v18.4s\n" + "zip2 v18.4s, v7.4s, v19.4s\n" + ".inst 0x0ea16871 // bfcvtn v17.4h, v3.4s\n" + "zip2 v16.4s, v21.4s, v25.4s\n" + ".inst 0x4ea1690b // bfcvtn2 v11.8h, v8.4s\n" + ".inst 0x4ea16b69 // bfcvtn2 v9.8h, v27.4s\n" + ".inst 0x4ea16886 // bfcvtn2 v6.8h, v4.4s\n" + ".inst 0x4ea16822 // bfcvtn2 v2.8h, v1.4s\n" + ".inst 0x4ea168bd // bfcvtn2 v29.8h, v5.4s\n" + ".inst 0x4ea16814 // bfcvtn2 v20.8h, v0.4s\n" + ".inst 0x4ea16bf6 // bfcvtn2 v22.8h, v31.4s\n" + ".inst 0x4ea16b4a // bfcvtn2 v10.8h, v26.4s\n" + "str q11, [x27, #0x0]\n" + ".inst 0x4ea1698f // bfcvtn2 v15.8h, v12.4s\n" + ".inst 0x4ea16b0e // bfcvtn2 v14.8h, v24.4s\n" + "str q9, [x27, #0x10]\n" + ".inst 0x4ea16a57 // bfcvtn2 v23.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q6, [x27, #0x20]\n" + "str q2, [x27, #0x30]\n" + "str q29, [x27, #0x40]\n" + "str q20, [x27, #0x50]\n" + "str q22, [x27, #0x60]\n" + "str q10, [x27, #0x70]\n" + "str q15, [x27, #0x80]\n" + "str q14, [x27, #0x90]\n" + "str q23, [x27, #0xa0]\n" + "str q17, [x27, #0xb0]\n" + "add x27, x27, %x[out_stride]\n" + "bge 5b\n" + "6:" // Main row loop: Column loop skip + "cbz x28, 11f\n" + "cmp x28, #0x4\n" + "movi v16.16b, #0x0\n" + "str q16, [x27, #0x0]\n" + "str q16, [x27, #0x10]\n" + "str q16, [x27, #0x20]\n" + "str q16, [x27, #0x30]\n" + "str q16, [x27, #0x40]\n" + "str q16, [x27, #0x50]\n" + "str q16, [x27, #0x60]\n" + "str q16, [x27, #0x70]\n" + "str q16, [x27, #0x80]\n" + "str q16, [x27, #0x90]\n" + "str q16, [x27, #0xa0]\n" + "str q16, [x27, #0xb0]\n" + "blt 8f\n" + "7:" // Main row loop: width 4 loop: loop + "ldr q25, [x9], #0x10\n" + "ldr q24, [x26], #0x10\n" + "sub x28, x28, #0x4\n" + "ldr q21, [x25], #0x10\n" + "ldr q20, [x24], #0x10\n" + "cmp x28, #0x4\n" + "ldr q23, [x23], #0x10\n" + "ldr q19, [x22], #0x10\n" + "ldr q18, [x21], #0x10\n" + "ldr q17, [x20], #0x10\n" + "zip1 v22.4s, v25.4s, v21.4s\n" + "zip1 v16.4s, v24.4s, v20.4s\n" + "zip2 v21.4s, 
v25.4s, v21.4s\n" + "zip2 v20.4s, v24.4s, v20.4s\n" + "zip1 v27.4s, v23.4s, v18.4s\n" + "zip1 v26.4s, v19.4s, v17.4s\n" + "zip2 v25.4s, v23.4s, v18.4s\n" + "zip2 v24.4s, v19.4s, v17.4s\n" + "zip1 v19.4s, v22.4s, v16.4s\n" + "zip1 v18.4s, v21.4s, v20.4s\n" + "zip1 v17.4s, v27.4s, v26.4s\n" + "zip2 v23.4s, v22.4s, v16.4s\n" + "zip1 v16.4s, v25.4s, v24.4s\n" + "zip2 v22.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a75 // bfcvtn v21.4h, v19.4s\n" + ".inst 0x0ea16a54 // bfcvtn v20.4h, v18.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v27.4s, v26.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v25.4s, v24.4s\n" + ".inst 0x4ea16af5 // bfcvtn2 v21.8h, v23.4s\n" + ".inst 0x4ea16ad4 // bfcvtn2 v20.8h, v22.4s\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q21, [x27, #0x0]\n" + "str q20, [x27, #0x10]\n" + "str q19, [x27, #0x60]\n" + "str q17, [x27, #0x70]\n" + "add x27, x27, #0x20\n" + "bge 7b\n" + "8:" // Main row loop: width 4 loop: skip + "cmp x28, #0x1\n" + "blt 10f\n" + "9:" // Main row loop: width 1 loop: loop + "ldr s23, [x9], #0x4\n" + "ldr s22, [x26], #0x4\n" + "sub x28, x28, #0x1\n" + "ldr s19, [x25], #0x4\n" + "ldr s17, [x24], #0x4\n" + "cmp x28, #0x1\n" + "ldr s21, [x23], #0x4\n" + "ldr s20, [x22], #0x4\n" + "ldr s18, [x21], #0x4\n" + "ldr s16, [x20], #0x4\n" + "zip1 v19.4s, v23.4s, v19.4s\n" + "zip1 v17.4s, v22.4s, v17.4s\n" + "zip1 v18.4s, v21.4s, v18.4s\n" + "zip1 v16.4s, v20.4s, v16.4s\n" + "zip1 v17.4s, v19.4s, v17.4s\n" + "zip1 v16.4s, v18.4s, v16.4s\n" + ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "str d17, [x27, #0x0]\n" + "str d16, [x27, #0x60]\n" + "add x27, x27, #0x8\n" + "bge 9b\n" + "10:" // Main row loop: width 1 loop: skip + "11:" // Main row loop: odd col skip + "cmp %x[height], #0x8\n" + "add %x[out], %x[out], #0xc0\n" + "bge 4b\n" + "cbz %x[height], 21f\n" + "12:" // Main loop skip + "13:" // Tail 
row loop: Head + "mov x9, %x[in]\n" + "mov x20, %x[width]\n" + "cmp %x[height], #0x3\n" + "mov x27, %x[out]\n" + "add x26, x9, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "csel x25, x25, %x[pad_row], GE\n" + "add %x[in], x24, %x[in_stride]\n" + "csel x24, x24, %x[pad_row], GT\n" + "cmp %x[height], #0x1\n" + "sub %x[height], %x[height], #0x4\n" + "csel x26, x26, %x[pad_row], GT\n" + "cmp x20, #0xc\n" + "blt 15f\n" + "14:" // Tail row loop: Column loop + "ldr q24, [x9], #0x10\n" + "ldr q23, [x26], #0x10\n" + "sub x20, x20, #0xc\n" + "ldr q22, [x25], #0x10\n" + "ldr q16, [x24], #0x10\n" + "cmp x20, #0xc\n" + "ldr q28, [x9], #0x10\n" + "ldr q27, [x26], #0x10\n" + "ldr q21, [x25], #0x10\n" + "ldr q20, [x24], #0x10\n" + "ldr q19, [x9], #0x10\n" + "zip1 v26.4s, v24.4s, v22.4s\n" + "zip1 v25.4s, v23.4s, v16.4s\n" + "ldr q18, [x26], #0x10\n" + "ldr q17, [x25], #0x10\n" + "zip2 v24.4s, v24.4s, v22.4s\n" + "zip2 v23.4s, v23.4s, v16.4s\n" + "ldr q16, [x24], #0x10\n" + "zip1 v2.4s, v28.4s, v21.4s\n" + "zip1 v22.4s, v27.4s, v20.4s\n" + "zip2 v1.4s, v28.4s, v21.4s\n" + "zip2 v0.4s, v27.4s, v20.4s\n" + "zip1 v31.4s, v19.4s, v17.4s\n" + "zip1 v30.4s, v18.4s, v16.4s\n" + "zip2 v29.4s, v19.4s, v17.4s\n" + "zip2 v28.4s, v18.4s, v16.4s\n" + "zip1 v21.4s, v26.4s, v25.4s\n" + "zip1 v20.4s, v24.4s, v23.4s\n" + "zip1 v19.4s, v2.4s, v22.4s\n" + "zip1 v18.4s, v1.4s, v0.4s\n" + "zip1 v17.4s, v31.4s, v30.4s\n" + "zip1 v16.4s, v29.4s, v28.4s\n" + ".inst 0x0ea16abb // bfcvtn v27.4h, v21.4s\n" + "zip2 v26.4s, v26.4s, v25.4s\n" + ".inst 0x0ea16a99 // bfcvtn v25.4h, v20.4s\n" + "zip2 v24.4s, v24.4s, v23.4s\n" + ".inst 0x0ea16a77 // bfcvtn v23.4h, v19.4s\n" + "zip2 v22.4s, v2.4s, v22.4s\n" + ".inst 0x0ea16a55 // bfcvtn v21.4h, v18.4s\n" + "zip2 v20.4s, v1.4s, v0.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v31.4s, v30.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v29.4s, v28.4s\n" + ".inst 0x4ea16b5b // 
bfcvtn2 v27.8h, v26.4s\n" + ".inst 0x4ea16b19 // bfcvtn2 v25.8h, v24.4s\n" + ".inst 0x4ea16ad7 // bfcvtn2 v23.8h, v22.4s\n" + ".inst 0x4ea16a95 // bfcvtn2 v21.8h, v20.4s\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q27, [x27, #0x0]\n" + "str q25, [x27, #0x10]\n" + "str q23, [x27, #0x20]\n" + "str q21, [x27, #0x30]\n" + "str q19, [x27, #0x40]\n" + "str q17, [x27, #0x50]\n" + "add x27, x27, %x[out_stride]\n" + "bge 14b\n" + "15:" // Tail row loop: Column loop skip + "cbz x20, 20f\n" + "cmp x20, #0x4\n" + "movi v16.16b, #0x0\n" + "str q16, [x27, #0x0]\n" + "str q16, [x27, #0x10]\n" + "str q16, [x27, #0x20]\n" + "str q16, [x27, #0x30]\n" + "str q16, [x27, #0x40]\n" + "str q16, [x27, #0x50]\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr q21, [x9], #0x10\n" + "ldr q20, [x26], #0x10\n" + "sub x20, x20, #0x4\n" + "ldr q19, [x25], #0x10\n" + "ldr q17, [x24], #0x10\n" + "cmp x20, #0x4\n" + "zip1 v18.4s, v21.4s, v19.4s\n" + "zip1 v16.4s, v20.4s, v17.4s\n" + "zip2 v21.4s, v21.4s, v19.4s\n" + "zip2 v20.4s, v20.4s, v17.4s\n" + "zip1 v17.4s, v18.4s, v16.4s\n" + "zip2 v19.4s, v18.4s, v16.4s\n" + "zip1 v16.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a32 // bfcvtn v18.4h, v17.4s\n" + "zip2 v17.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + ".inst 0x4ea16a72 // bfcvtn2 v18.8h, v19.4s\n" + ".inst 0x4ea16a30 // bfcvtn2 v16.8h, v17.4s\n" + "str q18, [x27, #0x0]\n" + "str q16, [x27, #0x10]\n" + "add x27, x27, #0x20\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr s19, [x9], #0x4\n" + "ldr s18, [x26], #0x4\n" + "sub x20, x20, #0x1\n" + "ldr s17, [x25], #0x4\n" + "ldr s16, [x24], #0x4\n" + "cmp x20, #0x1\n" + "zip1 v17.4s, v19.4s, v17.4s\n" + "zip1 v16.4s, v18.4s, v16.4s\n" + "zip1 v16.4s, v17.4s, v16.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "str d16, [x27, #0x0]\n" + "add 
x27, x27, #0x8\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "20:" // Tail row loop: odd col skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x60\n" + "bge 13b\n" + "21:" // Done + : [bias] "+&r" (bias), [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.h b/kai/ukernels/matmul/pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.h new file mode 100644 index 00000000..c12f253f --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.h @@ -0,0 +1,93 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets n step value. +/// +/// The starting row index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(void); + +/// Gets the offset in bytes to the data element in the RHS matrix buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(size_t n_idx); + + +/// Gets the offset in bytes to the data element in the bias buffer. +/// +/// @param[in] n_idx Column index. 
+/// +/// @return The offset in bytes to the data element. +size_t kai_get_bias_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(size_t n_idx); + + +/// Gets the offset in bytes to the data element in the packed RHS buffer. +/// +/// @param[in] n_idx Row index. +/// @param[in] k Number of columns. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(size_t n_idx, size_t k); + +/// Gets the size in bytes of the packed RHS buffer. +/// +/// @param[in] n Number of rows. +/// @param[in] k Number of columns. +/// +/// @return The size in bytes of the packed RHS buffer. +size_t kai_get_rhs_packed_size_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(size_t n, size_t k); + +/// Runs the RHS packing function for matrix multiplication. +/// +/// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset +/// calculated using the following functions: +/// +/// * RHS: @ref kai_get_rhs_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12. +/// * Bias: @ref kai_get_bias_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12. +/// * Output: @ref kai_get_rhs_packed_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12. +/// +/// @param[in] num_groups Number of groups. It must be 1. +/// @param[in] n Number of columns of the output matrix. +/// @param[in] k Common dimension between the LHS and RHS matrix. +/// @param[in] nr Block size in N dimension. It must be 12. +/// @param[in] kr Block size in K dimension. It must be 4. +/// @param[in] sr Number of kr splits. It must be 1. +/// @param[in] rhs_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs RHS matrix data buffer. +/// @param[in] bias Bias matrix data buffer. +/// @param[in] scale Scale data buffer. It must be NULL. +/// @param[out] rhs_packed Packed RHS matrix. 
+/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. +/// @param[in] params Extra packing parameters. It must be NULL. +void kai_run_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12( + size_t num_groups, + size_t n, + size_t k, + size_t nr, + size_t kr, + size_t sr, + size_t rhs_stride, + const void* rhs, + const void* bias, + const void* scale, + void* rhs_packed, + size_t extra_bytes, + const void* params); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus -- GitLab From af16adb78ad7bd29789825f1e3f3683aeb2de3c8 Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Sat, 28 Sep 2024 22:19:24 +0300 Subject: [PATCH 02/10] Refactor kernel names and add tests This commit - adds the unit tests for bf16 interleaved gemm kernel - refactors matmul_test.cpp so that its data types and the infrastructure can be used in other files - adds an example for bf16 interleaved gemm kernel Signed-off-by: Gunes Bayir --- CMakeLists.txt | 8 +- .../CMakeLists.txt | 34 + .../matmul_clamp_f32_bf16p_bf16p.cpp | 331 ++++++++++ kai/ukernels/matmul/BUILD.bazel | 26 +- ...16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.c | 596 ------------------ .../matmul_clamp_bf16_bf16_f32p_interface.h | 57 -- ..._bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c | 583 +++++++++++++++++ ...bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h} | 45 +- .../matmul_clamp_f32_bf16p_bf16p_interface.h | 57 ++ .../pack/kai_lhs_pack_8x4_f32_bf16_neon.c | 222 ------- .../pack/kai_lhs_pack_f32p8x4_bf16_neon.c | 212 +++++++ ...eon.h => kai_lhs_pack_f32p8x4_bf16_neon.h} | 18 +- ...s_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.c | 474 -------------- ...s_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c | 461 ++++++++++++++ ..._pack_kxn_f32p4x12biasf32_f32_bf16_neon.h} | 35 +- test/common/MatMulMethod.hpp | 330 ++++++++++ test/common/bfloat16.hpp | 8 + test/common/compare.cpp | 3 + test/common/matmul_test_common.cpp | 25 + test/common/matmul_test_common.hpp | 26 + 
test/common/memory.hpp | 12 + test/reference/pack.cpp | 149 +++-- .../matmul_clamp_f32_bf16p_bf16p_test.cpp | 328 ++++++++++ test/tests/matmul_test.cpp | 335 +--------- 24 files changed, 2584 insertions(+), 1791 deletions(-) create mode 100644 examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt create mode 100644 examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp delete mode 100644 kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.c delete mode 100644 kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/matmul_clamp_bf16_bf16_f32p_interface.h create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c rename kai/ukernels/matmul/{matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.h => matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h} (64%) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p_interface.h delete mode 100644 kai/ukernels/matmul/pack/kai_lhs_pack_8x4_f32_bf16_neon.c create mode 100644 kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.c rename kai/ukernels/matmul/pack/{kai_lhs_pack_8x4_f32_bf16_neon.h => kai_lhs_pack_f32p8x4_bf16_neon.h} (51%) delete mode 100644 kai/ukernels/matmul/pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.c create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c rename kai/ukernels/matmul/pack/{kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.h => kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h} (64%) create mode 100644 test/common/MatMulMethod.hpp create mode 100644 test/common/matmul_test_common.cpp create mode 100644 test/common/matmul_test_common.hpp create mode 100644 test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 
02f535e6..569b71b0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -89,9 +89,9 @@ set(KLEIDIAI_FILES_NEON_FP16 ) set(KLEIDIAI_FILES_NEON_BF16 - kai/ukernels/matmul/pack/kai_lhs_pack_8x4_f32_bf16_neon.c - kai/ukernels/matmul/pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.c - kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.c + kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c + kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c ) set(KLEIDIAI_FILES_NEON @@ -178,6 +178,7 @@ if(KLEIDIAI_BUILD_TESTS) test/common/printer.cpp test/common/int4.cpp test/common/compare.cpp + test/common/matmul_test_common.cpp test/common/matrix_portion.cpp test/common/rect.cpp test/common/round.cpp @@ -213,6 +214,7 @@ if(KLEIDIAI_BUILD_TESTS) test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp + test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp ) target_link_libraries(kleidiai_test diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt new file mode 100644 index 00000000..62007cf7 --- /dev/null +++ b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt @@ -0,0 +1,34 @@ +# +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +cmake_minimum_required(VERSION 3.16) + +set(CMAKE_CXX_STANDARD 17) +set(KLEIDIAI_PATH ../../) +set(MATMUL_PACK_PATH ${KLEIDIAI_PATH}/kai/ukernels/matmul/pack/) +set(MATMUL_PATH ${KLEIDIAI_PATH}/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/) + +# KleidiAI include directories +include_directories( + ${KLEIDIAI_PATH} + ${MATMUL_PACK_PATH} + ${MATMUL_PATH}) + +# Files requires to build the 
executable +add_executable(matmul_clamp_f32_bf16p_bf16p + matmul_clamp_f32_bf16p_bf16p.cpp + ${MATMUL_PATH}/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c + ${MATMUL_PACK_PATH}/kai_lhs_pack_f32p8x4_bf16_neon.c + ${MATMUL_PACK_PATH}/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c +) + +target_compile_options(matmul_clamp_f32_bf16p_bf16p + PRIVATE -march=armv8.2-a+bf16 +) + +target_compile_definitions(matmul_clamp_f32_bf16p_bf16p + PRIVATE $<$:KAI_DEBUG> +) diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp new file mode 100644 index 00000000..49f5a1f9 --- /dev/null +++ b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp @@ -0,0 +1,331 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Example usage for matrix multiplication of two half precision floating-point (FP16) matrices and the accumulation of +// the result into an FP16 destination matrix. +// +// The activations and the weights, stored in the LHS and RHS matrices respectively, are both non-transposed matrices. +// The matrix multiplication computation is performed using floating-point fused multiply-add to accumulator (FMLA) +// vector instructions present in the FEAT_FP16 Arm® architecture feature. +// +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) || \ + !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_FP16. 
+#else +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +// Include micro-kernel variants +#include "kai_lhs_pack_f32p8x4_bf16_neon.h" +#include "kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h" +#include "kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h" +#include "matmul_clamp_f32_bf16p_bf16p_interface.h" + +inline float bf16_to_float(uint16_t v) { + const uint32_t lv = (v << 16); + float fp; + memcpy(&fp, &lv, sizeof(lv)); + return fp; +} + +inline float bf16_to_float(const bfloat16_t* v) { + const uint16_t uint_rep = *reinterpret_cast(v); + return bf16_to_float(uint_rep); +} + +namespace { +/// Micro-kernel interface +constexpr kai_matmul_clamp_f32_bf16p_bf16p_ukernel ukernel{ + kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + kai_get_lhs_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + kai_run_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla}; + +float truncate(float x) { + uint32_t uval = (*reinterpret_cast(&x) & 0xffff0000); + return *reinterpret_cast(&uval); +} + +/// Reference implementation of matrix multiplication +void run_matmul_ref( + size_t m, size_t n, size_t k, const float* lhs, const float* rhs, const float* bias, float* dst, float scalar_min, + float scalar_max) { + for (size_t row_idx = 0; row_idx < m; ++row_idx) { + for (size_t col_idx = 0; 
col_idx < n; ++col_idx) { + float acc = bias[col_idx]; + + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + float lhs_val = truncate(lhs[row_idx * k + k_idx]); + float rhs_val = truncate(rhs[col_idx + n * k_idx]); + + acc += lhs_val * rhs_val; + } + acc = std::max(acc, scalar_min); + acc = std::min(acc, scalar_max); + + dst[row_idx * n + col_idx] = acc; + } + } +} + +/// Fills the matrix with incremental values +void fill_matrix(size_t num_rows, size_t num_cols, float* dst, const float weight) { + for (size_t i = 0; i < num_rows * num_cols; i++) { + dst[i] = float((i + 1) * weight); + } +} + +void fill_identity(size_t num_rows, size_t num_cols, float* dst, const float weight) { + for (size_t i = 0; i < num_rows * num_cols; i++) { + int col = i % num_cols; + int row = i / num_cols; + + dst[i] = (col == row ? 1.f : 0.f); + } +} + +/// Print the matrix +void print_matrix(size_t num_rows, size_t num_cols, const char* name, const float* src) { + std::cout << name << " = [\n"; + for (size_t y = 0; y < num_rows; ++y) { + std::cout << " ["; + for (size_t x = 0; x < num_cols; ++x) { + std::cout << std::setprecision(2) << std::fixed << src[y * num_cols + x] << ", "; + } + std::cout << ("],\n"); + } + std::cout << ("]\n\n"); +} + +void print_matrix(size_t num_rows, size_t num_cols, const char* name, const bfloat16_t* src) { + std::cout << name << " = [\n"; + for (size_t y = 0; y < num_rows; ++y) { + std::cout << " ["; + for (size_t x = 0; x < num_cols; ++x) { + std::cout << std::setprecision(2) << std::fixed << bf16_to_float(&src[y * num_cols + x]) << ", "; + } + std::cout << ("],\n"); + } + std::cout << ("]\n\n"); +} + +void print_mixed_prec_matrix( + size_t num_rows, size_t num_cols, const char* name, const uint8_t* src, int nr, int stride) { + std::cout << name << " = [\n"; + for (size_t y = 0; y < num_rows; ++y) { + std::cout << " ["; + const uint8_t* src_row = src + stride * y; + for (size_t x = 0; x < num_cols; ++x) { + if (x >= nr) { + // print bfloat + const bfloat16_t* 
src_elm = + reinterpret_cast(src_row + nr * sizeof(float) + (x - nr) * sizeof(bfloat16_t)); + std::cout << std::setprecision(2) << std::fixed << bf16_to_float(src_elm) << ", "; + } else { + // print float + const float* src_elm = reinterpret_cast(src_row + x * sizeof(float)); + std::cout << std::setprecision(2) << std::fixed << *src_elm << ", "; + } + } + std::cout << ("],\n"); + } + std::cout << ("]\n\n"); +} + +void print_bf_matrix(size_t num_rows, size_t num_cols, const char* name, const float* src) { + std::cout << name << " = [\n"; + for (size_t y = 0; y < num_rows; ++y) { + std::cout << " ["; + for (size_t x = 0; x < num_cols; ++x) { + std::cout << std::setprecision(2) << std::fixed << truncate(src[y * num_cols + x]) << ", "; + } + std::cout << ("],\n"); + } + std::cout << ("]\n\n"); +} + +/// Verify the micro-kernel output matches the reference implementation +bool is_output_correct( + size_t num_rows, size_t num_cols, const float rel_tolerance, const float* ref, const float* act) { + bool is_valid = true; + + for (size_t i = 0; i < num_rows * num_cols; ++i) { + if (std::fabs(ref[i] - act[i]) / (act[i] + 1e-10) > rel_tolerance) { + const size_t x = i % num_cols; + const size_t y = i / num_cols; + + std::cout << std::setprecision(5) << std::fixed << "ERROR![" << y << "][" << x << "]: ref=" << ref[i] + << " vs. act=" << act[i] << "\n"; + + is_valid = false; + } + } + return is_valid; +} +} // namespace + +int main() { + // Parameters of the matrix multiplication. Change these values to see how the micro-kernels operate on different + // sized matrices + const size_t M = 5; // Rows of LHS and DST matrices + const size_t N = 8; // Columns of RHS and DST matrices, and length of the Bias vector. 
+ const size_t K = 7; // Columns of LHS, rows of RHS matrices + + const size_t lhs_size = M * K; + const size_t rhs_size = N * K; + const size_t bias_size = N; + const size_t dst_size = M * N; + + // Allocate the memory + float* lhs = new float[lhs_size]; + float* rhs = new float[rhs_size]; + float* bias = new float[bias_size]; + + fill_matrix(M, K, lhs, 0.1); + fill_matrix(K, N, rhs, 0.1); + fill_matrix(1, N, bias, 1); + +#ifdef KAI_DEBUG + // std::cout << "Floats: " << std::endl; + print_matrix(M, K, "lhs", lhs); + print_matrix(K, N, "rhs", rhs); + print_matrix(1, N, "bias", bias); + + // Print bf16 converted values + print_bf_matrix(M, K, "lhs_bf", lhs); + print_bf_matrix(K, N, "rhs_bf", rhs); +#endif // KAI_DEBUG + + //----------- REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + float* dst_ref = new float[dst_size]; + + run_matmul_ref( + M, N, K, // Dimensions + lhs, // LHS buffer + rhs, // RHS buffer + bias, // Bias buffer + dst_ref, // DST + FLT_MIN, FLT_MAX // Min and max for the clamp operation + ); + //----------- END REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + + //----------- MICRO-KERNELS TESTS + //------------------------------------ + //------------------------------------ + const size_t mr = ukernel.get_mr(); + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + + // In a single row, we pack nr bias values followed by K rows of nr RHS values + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(N, K); + uint8_t* rhs_packed = new uint8_t[rhs_packed_size]; + + const size_t lhs_stride = K * sizeof(float); + const size_t rhs_stride = N * sizeof(float); + const size_t dst_stride_row = N * sizeof(float); + const size_t dst_stride_col = sizeof(float); + + const size_t lhs_packed_size = 
kai_get_lhs_packed_size_lhs_pack_f32p8x4_bf16_neon(M, K, mr, kr, sr); + bfloat16_t* lhs_packed = new bfloat16_t[lhs_packed_size]; + + // Packing only needs to be performed once if the contents of the bias and RHS matrices are expected to be constant. + kai_run_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon( + 1, N, K, nr, kr, sr, // Packing arguments + rhs_stride, // RHS stride + rhs, // RHS + bias, // Bias + NULL, // Scale + rhs_packed, // RHS packed + 0, NULL); + + // The RHS and Bias buffers can be freed after packing, however we reuse them for the reference test below + +#ifdef KAI_DEBUG + const size_t rhs_packed_cols = nr + kai_roundup(K, kr) * nr; + + // Each col has nr floats and then K*nr bfloats + int rhs_packed_stride = nr * sizeof(float) + kai_roundup(K, kr) * nr * sizeof(bfloat16_t); + const size_t rhs_packed_rows = rhs_packed_size / rhs_packed_stride; + + print_mixed_prec_matrix(rhs_packed_rows, rhs_packed_cols, "rhs_packed", rhs_packed, nr, rhs_packed_stride); +#endif // KAI_DEBUG + + float* dst = new float[dst_size]; + + kai_run_lhs_pack_f32p8x4_bf16_neon(M, K, mr, kr, sr, 0 /* m_idx_start */, lhs, lhs_stride, lhs_packed); + + const auto timer_matmul_start = std::chrono::high_resolution_clock::now(); + + ukernel.run_matmul( + M, N, K, // Dimensions + lhs_packed, // LHS packed + rhs_packed, // RHS packed + dst, // DST + dst_stride_row, // DST stride (row) + dst_stride_col, // DST stride (col) + FLT_MIN, FLT_MAX // Min and max for the clamp operation + ); + + const auto timer_matmul_end = std::chrono::high_resolution_clock::now(); + const auto time_matmul = + std::chrono::duration_cast(timer_matmul_end - timer_matmul_start); + +#ifdef KAI_DEBUG + int num_lhs_rows = (M + mr - 1) / mr; + int num_lhs_cols = mr * kai_roundup(K, kr); + + print_matrix(num_lhs_rows, num_lhs_cols, "lhs_packed", lhs_packed); + print_matrix(M, N, "dst", dst); + print_matrix(M, N, "ref", dst_ref); +#endif // KAI_DEBUG + + const bool is_valid = is_output_correct(M, N, 0.02 /* rel tol 
*/, dst_ref, dst); + + std::cout << "TEST[matmul_clamp_f32_bf16p_bf16p]\n"; + std::cout << "- ukernel: matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla\n"; + if (is_valid) { + std::cout << "- Status: PASSED\n"; + std::cout << "- Performance: " << time_matmul.count() << "ns\n"; + } else { + std::cout << "- Status: FAILED\n"; + return 1; + } + + //----------- END MICRO-KERNELS TESTS + //------------------------------------ + //------------------------------------ + + delete[] lhs; + delete[] rhs; + delete[] bias; + delete[] rhs_packed; + delete[] dst; + delete[] dst_ref; + + return 0; +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index f9dc2c13..b2f6ab58 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -7,9 +7,9 @@ load( "//:kai_defs.bzl", "kai_c_library", + "kai_cpu_bf16", "kai_cpu_dotprod", "kai_cpu_fp16", - "kai_cpu_bf16", "kai_cpu_i8mm", "kai_cpu_neon", "kai_cpu_sme", @@ -34,18 +34,18 @@ kai_c_library( ) kai_c_library( - name = "clamp_bf16_bf16_f32p_interface", - hdrs = ["matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p_interface.h"], + name = "clamp_f32_bf16p_bf16p_interface", + hdrs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h"], cpu_uarch = kai_cpu_bf16(), ) kai_c_library( - name = "clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla", - srcs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.c"], - hdrs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.h"], + name = "clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla", + srcs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c"], + hdrs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h"], cpu_uarch = kai_cpu_bf16(), deps = [ - ":clamp_bf16_bf16_f32p_interface", + 
":clamp_f32_bf16p_bf16p_interface", ], ) @@ -177,9 +177,9 @@ kai_c_library( ) kai_c_library( - name = "lhs_pack_8x4_f32_bf16_neon", - srcs = ["pack/kai_lhs_pack_8x4_f32_bf16_neon.c"], - hdrs = ["pack/kai_lhs_pack_8x4_f32_bf16_neon.h"], + name = "lhs_pack_f32p8x4_bf16_neon", + srcs = ["pack/kai_lhs_pack_f32p8x4_bf16_neon.c"], + hdrs = ["pack/kai_lhs_pack_f32p8x4_bf16_neon.h"], cpu_uarch = kai_cpu_bf16(), ) @@ -191,9 +191,9 @@ kai_c_library( ) kai_c_library( - name = "matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12", - srcs = ["pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.c"], - hdrs = ["pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.h"], + name = "rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon", + srcs = ["pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c"], + hdrs = ["pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h"], cpu_uarch = kai_cpu_bf16(), ) diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.c b/kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.c deleted file mode 100644 index 8e845095..00000000 --- a/kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.c +++ /dev/null @@ -1,596 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) -#error This file must be compiled for AArch64, FEAT_BF16. -#else // Architectural features check. 
- -#include -#include -#include -#include -#include - -typedef bfloat16_t bfloat16; - -#include "kai/kai_common.h" -#include "kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.h" - -static const size_t kai_mr = 8; -static const size_t kai_nr = 12; -static const size_t kai_kr = 4; -static const size_t kai_sr = 1; - -size_t kai_get_m_step_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void) { - return kai_mr; -} - -size_t kai_get_n_step_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void) { - return kai_nr; -} - -size_t kai_get_mr_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void) { - return kai_mr; -} - -size_t kai_get_nr_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void) { - return kai_nr; -} - -size_t kai_get_kr_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void) { - return kai_kr; -} - -size_t kai_get_sr_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void) { - return kai_sr; -} - -size_t kai_get_lhs_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(size_t m_idx, size_t stride) { - KAI_ASSUME(m_idx % kai_mr == 0); - - return m_idx * stride; -} - -size_t kai_get_rhs_packed_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(size_t n_idx, size_t k) { - KAI_ASSUME(n_idx % kai_nr == 0); - - return n_idx / kai_nr * (kai_nr * sizeof(float) + kai_nr * k * sizeof(bfloat16)); -} - -size_t kai_get_dst_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(size_t m_idx, size_t n_idx, size_t stride) { - KAI_ASSUME(m_idx % kai_mr == 0); - KAI_ASSUME(n_idx % kai_nr == 0); - - return m_idx * stride + n_idx * sizeof(float); -} - -size_t kai_get_dst_size_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(size_t m, size_t n) { - return m * n * sizeof(float); -} - -void kai_run_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla( - size_t m, size_t n, size_t k, // - const void* lhs_packed, size_t lhs_stride, // - const void* rhs_packed, // - void* dst, size_t dst_stride_row, 
size_t dst_stride_col, // - float clamp_min, float clamp_max) { - KAI_ASSERT(dst_stride_col == sizeof(float)); - - KAI_UNUSED(lhs_stride); - - const void *Apanel = lhs_packed; - // const void *Bpanel = rhs_packed; - void *Cpanel = dst; - size_t ldc = dst_stride_row / sizeof(float); - - size_t M = m; - - typedef struct { - float maxval; - float minval; - unsigned int num_strings; - const unsigned int* string_lengths; - size_t N; - size_t K; - const void* Bpanel; - size_t output_offset; - size_t input_initial_col; - size_t input_offset; - void* output_ptr; - const void* bias; - } KernelArgs; - - KernelArgs ka; - - unsigned int string_length = k; - ka.num_strings = 1; - ka.string_lengths = &string_length; - ka.N = n; - ka.K = kai_roundup(k, 4) / 4 - 1; - - ka.Bpanel = rhs_packed; - ka.bias = NULL; - - // Direct input. - // const void* input_ptr = lhs; - // ka.input_offset = lhs_stride / sizeof(bfloat16); - ka.input_initial_col = 0; - - // Direct output. - ka.output_ptr = dst; - // ka.output_offset = dst_stride_row / sizeof(float); - - // Clamping output. 
- ka.maxval = clamp_max; - ka.minval = clamp_min; - - __asm__ __volatile__( - "1:" // Height loop - "add x11, %x[Cpanel], %x[ldc], LSL #2\n" - "add x10, %x[Cpanel], %x[ldc], LSL #1\n" - "add x9, x11, %x[ldc], LSL #1\n" - "cmp %x[M], #0x8\n" - "add x28, %x[Cpanel], %x[ldc], LSL #3\n" - "add x27, %x[Cpanel], %x[ldc]\n" - "add x26, x10, %x[ldc]\n" - "add x25, x11, %x[ldc]\n" - "add x24, x9, %x[ldc]\n" - "bge 2f\n" - "cmp %x[M], #0x2\n" - "mov x24, %x[Cpanel]\n" - "csel x27, x27, %x[Cpanel], GE\n" - "csel x10, x10, %x[Cpanel], GT\n" - "cmp %x[M], #0x4\n" - "csel x26, x26, %x[Cpanel], GE\n" - "csel x11, x11, %x[Cpanel], GT\n" - "cmp %x[M], #0x6\n" - "csel x25, x25, %x[Cpanel], GE\n" - "csel x9, x9, %x[Cpanel], GT\n" - "2:" // all rows valid - "ldr x23, [%x[args_ptr], %[offsetof_N]]\n" - "ldr x22, [%x[args_ptr], %[offsetof_Bpanel]]\n" - "mov x21, %x[Apanel]\n" - "3:" // Width loop - "ldr q4, [x22, #0x0]\n" - "ldr q5, [x22, #0x10]\n" - "mov %x[Apanel], x21\n" - "ldr q6, [x22, #0x20]\n" - "ldr x20, [%x[args_ptr], %[offsetof_K]]\n" - "add x22, x22, #0x30\n" - "ldr q7, [x22, #0x0]\n" - "ldr q0, [%x[Apanel], #0x0]\n" - "ldr q1, [%x[Apanel], #0x10]\n" - "zip1 v8.2d, v4.2d, v4.2d\n" - "ldr q2, [%x[Apanel], #0x20]\n" - "zip2 v11.2d, v4.2d, v4.2d\n" - "ldr q4, [x22, #0x10]\n" - "zip1 v9.2d, v5.2d, v5.2d\n" - "zip2 v12.2d, v5.2d, v5.2d\n" - "cmp x20, #0x2\n" - "zip1 v10.2d, v6.2d, v6.2d\n" - "zip2 v13.2d, v6.2d, v6.2d\n" - "prfm pldl1keep, [%x[Apanel], #0x0]\n" - "mov v14.16b, v8.16b\n" - "mov v17.16b, v11.16b\n" - "prfm pldl1keep, [x22, #0x0]\n" - "mov v15.16b, v9.16b\n" - "mov v18.16b, v12.16b\n" - "prfm pldl1keep, [x22, #0x40]\n" - "mov v16.16b, v10.16b\n" - "mov v19.16b, v13.16b\n" - "prfm pldl1keep, [%x[Apanel], #0x40]\n" - "mov v20.16b, v8.16b\n" - "mov v21.16b, v9.16b\n" - "prfm pldl1keep, [x22, #0x80]\n" - "mov v22.16b, v10.16b\n" - "mov v23.16b, v11.16b\n" - "prfm pldl1keep, [%x[Apanel], #0x80]\n" - "mov v24.16b, v12.16b\n" - "mov v25.16b, v13.16b\n" - "prfm pldl1keep, 
[x22, #0xc0]\n" - "mov v26.16b, v8.16b\n" - "mov v27.16b, v9.16b\n" - "prfm pldl1keep, [x22, #0x100]\n" - "mov v28.16b, v10.16b\n" - "mov v29.16b, v11.16b\n" - "prfm pldl1keep, [%x[Apanel], #0xc0]\n" - "mov v30.16b, v12.16b\n" - "mov v31.16b, v13.16b\n" - "prfm pldl1keep, [x22, #0x140]\n" - "add x22, x22, #0x20\n" - "add %x[Apanel], %x[Apanel], #0x30\n" - "blt 5f\n" - "4:" // main loop head - "ldr q3, [%x[Apanel], #0x0]\n" - "ldr q5, [x22, #0x0]\n" - ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" - "ldr q6, [x22, #0x10]\n" - ".inst 0x6e44ec0b // bfmmla v11.4s, v0.8h, v4.8h\n" - ".inst 0x6e47ec2e // bfmmla v14.4s, v1.8h, v7.8h\n" - ".inst 0x6e44ec31 // bfmmla v17.4s, v1.8h, v4.8h\n" - ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" - "sub x20, x20, #0x2\n" - ".inst 0x6e44ec57 // bfmmla v23.4s, v2.8h, v4.8h\n" - ".inst 0x6e47ec7a // bfmmla v26.4s, v3.8h, v7.8h\n" - "ldr q7, [x22, #0x20]\n" - ".inst 0x6e44ec7d // bfmmla v29.4s, v3.8h, v4.8h\n" - "ldr q4, [x22, #0x30]\n" - ".inst 0x6e45ec09 // bfmmla v9.4s, v0.8h, v5.8h\n" - ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" - ".inst 0x6e45ec2f // bfmmla v15.4s, v1.8h, v5.8h\n" - "cmp x20, #0x2\n" - ".inst 0x6e46ec32 // bfmmla v18.4s, v1.8h, v6.8h\n" - ".inst 0x6e45ec55 // bfmmla v21.4s, v2.8h, v5.8h\n" - "prfm pldl1keep, [%x[Apanel], #0x100]\n" - ".inst 0x6e46ec58 // bfmmla v24.4s, v2.8h, v6.8h\n" - ".inst 0x6e45ec7b // bfmmla v27.4s, v3.8h, v5.8h\n" - "ldr q5, [x22, #0x40]\n" - ".inst 0x6e46ec7e // bfmmla v30.4s, v3.8h, v6.8h\n" - "ldr q6, [x22, #0x50]\n" - ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" - ".inst 0x6e44ec0d // bfmmla v13.4s, v0.8h, v4.8h\n" - "ldr q0, [%x[Apanel], #0x10]\n" - ".inst 0x6e47ec30 // bfmmla v16.4s, v1.8h, v7.8h\n" - ".inst 0x6e44ec33 // bfmmla v19.4s, v1.8h, v4.8h\n" - "ldr q1, [%x[Apanel], #0x20]\n" - ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n" - ".inst 0x6e44ec59 // bfmmla v25.4s, v2.8h, v4.8h\n" - "ldr q2, [%x[Apanel], #0x30]\n" - ".inst 0x6e47ec7c // bfmmla 
v28.4s, v3.8h, v7.8h\n" - "ldr q7, [x22, #0x60]\n" - ".inst 0x6e44ec7f // bfmmla v31.4s, v3.8h, v4.8h\n" - "ldr q3, [%x[Apanel], #0x40]\n" - "ldr q4, [x22, #0x70]\n" - ".inst 0x6e45ec08 // bfmmla v8.4s, v0.8h, v5.8h\n" - ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" - ".inst 0x6e45ec2e // bfmmla v14.4s, v1.8h, v5.8h\n" - ".inst 0x6e46ec31 // bfmmla v17.4s, v1.8h, v6.8h\n" - "prfm pldl1keep, [x22, #0x180]\n" - ".inst 0x6e45ec54 // bfmmla v20.4s, v2.8h, v5.8h\n" - ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" - "prfm pldl1keep, [x22, #0x1c0]\n" - ".inst 0x6e45ec7a // bfmmla v26.4s, v3.8h, v5.8h\n" - "ldr q5, [x22, #0x80]\n" - ".inst 0x6e46ec7d // bfmmla v29.4s, v3.8h, v6.8h\n" - "ldr q6, [x22, #0x90]\n" - "prfm pldl1keep, [%x[Apanel], #0x140]\n" - ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" - "prfm pldl1keep, [x22, #0x200]\n" - ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" - ".inst 0x6e47ec2f // bfmmla v15.4s, v1.8h, v7.8h\n" - ".inst 0x6e44ec32 // bfmmla v18.4s, v1.8h, v4.8h\n" - ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n" - ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" - ".inst 0x6e47ec7b // bfmmla v27.4s, v3.8h, v7.8h\n" - "ldr q7, [x22, #0xa0]\n" - ".inst 0x6e44ec7e // bfmmla v30.4s, v3.8h, v4.8h\n" - "ldr q4, [x22, #0xb0]\n" - ".inst 0x6e45ec0a // bfmmla v10.4s, v0.8h, v5.8h\n" - ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" - "ldr q0, [%x[Apanel], #0x50]\n" - ".inst 0x6e45ec30 // bfmmla v16.4s, v1.8h, v5.8h\n" - ".inst 0x6e46ec33 // bfmmla v19.4s, v1.8h, v6.8h\n" - "ldr q1, [%x[Apanel], #0x60]\n" - ".inst 0x6e45ec56 // bfmmla v22.4s, v2.8h, v5.8h\n" - ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" - "ldr q2, [%x[Apanel], #0x70]\n" - ".inst 0x6e45ec7c // bfmmla v28.4s, v3.8h, v5.8h\n" - ".inst 0x6e46ec7f // bfmmla v31.4s, v3.8h, v6.8h\n" - "add %x[Apanel], %x[Apanel], #0x80\n" - "add x22, x22, #0xc0\n" - "bge 4b\n" - "5:" // main loop skip - "ldr q3, [%x[Apanel], #0x0]\n" - "ldr q5, [x22, #0x0]\n" - ".inst 
0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" - "ldr q6, [x22, #0x10]\n" - ".inst 0x6e44ec0b // bfmmla v11.4s, v0.8h, v4.8h\n" - ".inst 0x6e47ec2e // bfmmla v14.4s, v1.8h, v7.8h\n" - ".inst 0x6e44ec31 // bfmmla v17.4s, v1.8h, v4.8h\n" - ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" - "add %x[Apanel], %x[Apanel], #0x10\n" - ".inst 0x6e44ec57 // bfmmla v23.4s, v2.8h, v4.8h\n" - ".inst 0x6e47ec7a // bfmmla v26.4s, v3.8h, v7.8h\n" - "ldr q7, [x22, #0x20]\n" - ".inst 0x6e44ec7d // bfmmla v29.4s, v3.8h, v4.8h\n" - "ldr q4, [x22, #0x30]\n" - ".inst 0x6e45ec09 // bfmmla v9.4s, v0.8h, v5.8h\n" - ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" - ".inst 0x6e45ec2f // bfmmla v15.4s, v1.8h, v5.8h\n" - "add x22, x22, #0x40\n" - ".inst 0x6e46ec32 // bfmmla v18.4s, v1.8h, v6.8h\n" - ".inst 0x6e45ec55 // bfmmla v21.4s, v2.8h, v5.8h\n" - ".inst 0x6e46ec58 // bfmmla v24.4s, v2.8h, v6.8h\n" - ".inst 0x6e45ec7b // bfmmla v27.4s, v3.8h, v5.8h\n" - ".inst 0x6e46ec7e // bfmmla v30.4s, v3.8h, v6.8h\n" - ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" - ".inst 0x6e44ec0d // bfmmla v13.4s, v0.8h, v4.8h\n" - ".inst 0x6e47ec30 // bfmmla v16.4s, v1.8h, v7.8h\n" - ".inst 0x6e44ec33 // bfmmla v19.4s, v1.8h, v4.8h\n" - ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n" - ".inst 0x6e44ec59 // bfmmla v25.4s, v2.8h, v4.8h\n" - ".inst 0x6e47ec7c // bfmmla v28.4s, v3.8h, v7.8h\n" - ".inst 0x6e44ec7f // bfmmla v31.4s, v3.8h, v4.8h\n" - "cbz x20, 6f\n" - "ldr q5, [x22, #0x0]\n" - "ldr q0, [%x[Apanel], #0x0]\n" - "ldr q1, [%x[Apanel], #0x10]\n" - "ldr q6, [x22, #0x10]\n" - "ldr q2, [%x[Apanel], #0x20]\n" - "ldr q3, [%x[Apanel], #0x30]\n" - "add %x[Apanel], %x[Apanel], #0x40\n" - "ldr q7, [x22, #0x20]\n" - "ldr q4, [x22, #0x30]\n" - ".inst 0x6e45ec08 // bfmmla v8.4s, v0.8h, v5.8h\n" - ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" - ".inst 0x6e45ec2e // bfmmla v14.4s, v1.8h, v5.8h\n" - ".inst 0x6e46ec31 // bfmmla v17.4s, v1.8h, v6.8h\n" - ".inst 0x6e45ec54 // bfmmla v20.4s, v2.8h, 
v5.8h\n" - ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" - ".inst 0x6e45ec7a // bfmmla v26.4s, v3.8h, v5.8h\n" - "ldr q5, [x22, #0x40]\n" - ".inst 0x6e46ec7d // bfmmla v29.4s, v3.8h, v6.8h\n" - "ldr q6, [x22, #0x50]\n" - ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" - ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" - ".inst 0x6e47ec2f // bfmmla v15.4s, v1.8h, v7.8h\n" - "add x22, x22, #0x60\n" - ".inst 0x6e44ec32 // bfmmla v18.4s, v1.8h, v4.8h\n" - ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n" - ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" - ".inst 0x6e47ec7b // bfmmla v27.4s, v3.8h, v7.8h\n" - ".inst 0x6e44ec7e // bfmmla v30.4s, v3.8h, v4.8h\n" - ".inst 0x6e45ec0a // bfmmla v10.4s, v0.8h, v5.8h\n" - ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" - ".inst 0x6e45ec30 // bfmmla v16.4s, v1.8h, v5.8h\n" - ".inst 0x6e46ec33 // bfmmla v19.4s, v1.8h, v6.8h\n" - ".inst 0x6e45ec56 // bfmmla v22.4s, v2.8h, v5.8h\n" - ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" - ".inst 0x6e45ec7c // bfmmla v28.4s, v3.8h, v5.8h\n" - ".inst 0x6e46ec7f // bfmmla v31.4s, v3.8h, v6.8h\n" - "6:" // multiply loop done - "add x20, %x[args_ptr], %[offset_max]\n" - "uzp1 v7.2d, v8.2d, v11.2d\n" - "uzp2 v8.2d, v8.2d, v11.2d\n" - "ld1r { v1.4s }, [x20]\n" - "uzp1 v11.2d, v9.2d, v12.2d\n" - "uzp2 v9.2d, v9.2d, v12.2d\n" - "uzp1 v12.2d, v10.2d, v13.2d\n" - "uzp2 v10.2d, v10.2d, v13.2d\n" - "add x20, %x[args_ptr], %[offset_min]\n" - "ld1r { v0.4s }, [x20]\n" - "uzp1 v13.2d, v14.2d, v17.2d\n" - "uzp2 v14.2d, v14.2d, v17.2d\n" - "uzp1 v17.2d, v15.2d, v18.2d\n" - "uzp2 v15.2d, v15.2d, v18.2d\n" - "cmp x23, #0xc\n" - "uzp1 v18.2d, v16.2d, v19.2d\n" - "uzp2 v16.2d, v16.2d, v19.2d\n" - "uzp1 v19.2d, v20.2d, v23.2d\n" - "uzp2 v20.2d, v20.2d, v23.2d\n" - "uzp1 v23.2d, v21.2d, v24.2d\n" - "uzp2 v21.2d, v21.2d, v24.2d\n" - "uzp1 v24.2d, v22.2d, v25.2d\n" - "uzp2 v22.2d, v22.2d, v25.2d\n" - "uzp1 v25.2d, v26.2d, v29.2d\n" - "uzp2 v26.2d, v26.2d, v29.2d\n" - "uzp1 v29.2d, 
v27.2d, v30.2d\n" - "uzp2 v27.2d, v27.2d, v30.2d\n" - "uzp1 v30.2d, v28.2d, v31.2d\n" - "uzp2 v28.2d, v28.2d, v31.2d\n" - "fmin v7.4s, v7.4s, v1.4s\n" - "fmin v11.4s, v11.4s, v1.4s\n" - "fmin v12.4s, v12.4s, v1.4s\n" - "fmin v8.4s, v8.4s, v1.4s\n" - "fmin v9.4s, v9.4s, v1.4s\n" - "fmin v10.4s, v10.4s, v1.4s\n" - "fmin v13.4s, v13.4s, v1.4s\n" - "fmin v17.4s, v17.4s, v1.4s\n" - "fmin v18.4s, v18.4s, v1.4s\n" - "fmin v14.4s, v14.4s, v1.4s\n" - "fmin v15.4s, v15.4s, v1.4s\n" - "fmin v16.4s, v16.4s, v1.4s\n" - "fmin v19.4s, v19.4s, v1.4s\n" - "fmin v23.4s, v23.4s, v1.4s\n" - "fmin v24.4s, v24.4s, v1.4s\n" - "fmin v20.4s, v20.4s, v1.4s\n" - "fmin v21.4s, v21.4s, v1.4s\n" - "fmin v22.4s, v22.4s, v1.4s\n" - "fmin v25.4s, v25.4s, v1.4s\n" - "fmin v29.4s, v29.4s, v1.4s\n" - "fmin v30.4s, v30.4s, v1.4s\n" - "fmin v26.4s, v26.4s, v1.4s\n" - "fmin v27.4s, v27.4s, v1.4s\n" - "fmin v28.4s, v28.4s, v1.4s\n" - "fmax v7.4s, v7.4s, v0.4s\n" - "fmax v11.4s, v11.4s, v0.4s\n" - "fmax v12.4s, v12.4s, v0.4s\n" - "fmax v8.4s, v8.4s, v0.4s\n" - "fmax v9.4s, v9.4s, v0.4s\n" - "fmax v10.4s, v10.4s, v0.4s\n" - "fmax v13.4s, v13.4s, v0.4s\n" - "fmax v17.4s, v17.4s, v0.4s\n" - "fmax v18.4s, v18.4s, v0.4s\n" - "fmax v14.4s, v14.4s, v0.4s\n" - "fmax v15.4s, v15.4s, v0.4s\n" - "fmax v16.4s, v16.4s, v0.4s\n" - "fmax v19.4s, v19.4s, v0.4s\n" - "fmax v23.4s, v23.4s, v0.4s\n" - "fmax v24.4s, v24.4s, v0.4s\n" - "fmax v20.4s, v20.4s, v0.4s\n" - "fmax v21.4s, v21.4s, v0.4s\n" - "fmax v22.4s, v22.4s, v0.4s\n" - "fmax v25.4s, v25.4s, v0.4s\n" - "fmax v29.4s, v29.4s, v0.4s\n" - "fmax v30.4s, v30.4s, v0.4s\n" - "fmax v26.4s, v26.4s, v0.4s\n" - "fmax v27.4s, v27.4s, v0.4s\n" - "fmax v28.4s, v28.4s, v0.4s\n" - "blt 7f\n" - "str q26, [x24, #0x0]\n" - "str q27, [x24, #0x10]\n" - "str q28, [x24, #0x20]\n" - "add x24, x24, #0x30\n" - "str q25, [x9, #0x0]\n" - "str q29, [x9, #0x10]\n" - "str q30, [x9, #0x20]\n" - "add x9, x9, #0x30\n" - "str q20, [x25, #0x0]\n" - "str q21, [x25, #0x10]\n" - "str q22, [x25, 
#0x20]\n" - "add x25, x25, #0x30\n" - "str q19, [x11, #0x0]\n" - "str q23, [x11, #0x10]\n" - "str q24, [x11, #0x20]\n" - "add x11, x11, #0x30\n" - "str q14, [x26, #0x0]\n" - "str q15, [x26, #0x10]\n" - "str q16, [x26, #0x20]\n" - "add x26, x26, #0x30\n" - "str q13, [x10, #0x0]\n" - "str q17, [x10, #0x10]\n" - "str q18, [x10, #0x20]\n" - "add x10, x10, #0x30\n" - "str q8, [x27, #0x0]\n" - "str q9, [x27, #0x10]\n" - "str q10, [x27, #0x20]\n" - "add x27, x27, #0x30\n" - "str q7, [%x[Cpanel], #0x0]\n" - "str q11, [%x[Cpanel], #0x10]\n" - "str q12, [%x[Cpanel], #0x20]\n" - "add %x[Cpanel], %x[Cpanel], #0x30\n" - "b 14f\n" - "7:" // partial output - "tbz x23, #3, 9f\n" - "st1 { v26.4s }, [x24], #0x10\n" - "st1 { v27.4s }, [x24], #0x10\n" - "st1 { v25.4s }, [x9], #0x10\n" - "st1 { v29.4s }, [x9], #0x10\n" - "st1 { v20.4s }, [x25], #0x10\n" - "st1 { v21.4s }, [x25], #0x10\n" - "st1 { v19.4s }, [x11], #0x10\n" - "st1 { v23.4s }, [x11], #0x10\n" - "st1 { v14.4s }, [x26], #0x10\n" - "st1 { v15.4s }, [x26], #0x10\n" - "st1 { v13.4s }, [x10], #0x10\n" - "st1 { v17.4s }, [x10], #0x10\n" - "st1 { v8.4s }, [x27], #0x10\n" - "st1 { v9.4s }, [x27], #0x10\n" - "st1 { v7.4s }, [%x[Cpanel]], #0x10\n" - "st1 { v11.4s }, [%x[Cpanel]], #0x10\n" - "tbz x23, #1, 8f\n" - "str d28, [x24], #0x8\n" - "str d30, [x9], #0x8\n" - "str d22, [x25], #0x8\n" - "str d24, [x11], #0x8\n" - "str d16, [x26], #0x8\n" - "str d18, [x10], #0x8\n" - "str d10, [x27], #0x8\n" - "str d12, [%x[Cpanel]], #0x8\n" - "tbz x23, #0, 13f\n" - "st1 { v28.s }[2], [x24]\n" - "st1 { v30.s }[2], [x9]\n" - "st1 { v22.s }[2], [x25]\n" - "st1 { v24.s }[2], [x11]\n" - "st1 { v16.s }[2], [x26]\n" - "st1 { v18.s }[2], [x10]\n" - "st1 { v10.s }[2], [x27]\n" - "st1 { v12.s }[2], [%x[Cpanel]]\n" - "b 13f\n" - "8:" // partial result store: partial_1_8 - "tbz x23, #0, 13f\n" - "str s28, [x24, #0x0]\n" - "str s30, [x9, #0x0]\n" - "str s22, [x25, #0x0]\n" - "str s24, [x11, #0x0]\n" - "str s16, [x26, #0x0]\n" - "str s18, [x10, #0x0]\n" - 
"str s10, [x27, #0x0]\n" - "str s12, [%x[Cpanel], #0x0]\n" - "b 13f\n" - "9:" // partial result store: partial_4_0 - "tbz x23, #2, 11f\n" - "st1 { v26.4s }, [x24], #0x10\n" - "st1 { v25.4s }, [x9], #0x10\n" - "st1 { v20.4s }, [x25], #0x10\n" - "st1 { v19.4s }, [x11], #0x10\n" - "st1 { v14.4s }, [x26], #0x10\n" - "st1 { v13.4s }, [x10], #0x10\n" - "st1 { v8.4s }, [x27], #0x10\n" - "st1 { v7.4s }, [%x[Cpanel]], #0x10\n" - "tbz x23, #1, 10f\n" - "str d27, [x24], #0x8\n" - "str d29, [x9], #0x8\n" - "str d21, [x25], #0x8\n" - "str d23, [x11], #0x8\n" - "str d15, [x26], #0x8\n" - "str d17, [x10], #0x8\n" - "str d9, [x27], #0x8\n" - "str d11, [%x[Cpanel]], #0x8\n" - "tbz x23, #0, 13f\n" - "st1 { v27.s }[2], [x24]\n" - "st1 { v29.s }[2], [x9]\n" - "st1 { v21.s }[2], [x25]\n" - "st1 { v23.s }[2], [x11]\n" - "st1 { v15.s }[2], [x26]\n" - "st1 { v17.s }[2], [x10]\n" - "st1 { v9.s }[2], [x27]\n" - "st1 { v11.s }[2], [%x[Cpanel]]\n" - "b 13f\n" - "10:" // partial result store: partial_1_4 - "tbz x23, #0, 13f\n" - "str s27, [x24, #0x0]\n" - "str s29, [x9, #0x0]\n" - "str s21, [x25, #0x0]\n" - "str s23, [x11, #0x0]\n" - "str s15, [x26, #0x0]\n" - "str s17, [x10, #0x0]\n" - "str s9, [x27, #0x0]\n" - "str s11, [%x[Cpanel], #0x0]\n" - "b 13f\n" - "11:" // partial result store: partial_2_0 - "tbz x23, #1, 12f\n" - "str d26, [x24], #0x8\n" - "str d25, [x9], #0x8\n" - "str d20, [x25], #0x8\n" - "str d19, [x11], #0x8\n" - "str d14, [x26], #0x8\n" - "str d13, [x10], #0x8\n" - "str d8, [x27], #0x8\n" - "str d7, [%x[Cpanel]], #0x8\n" - "tbz x23, #0, 13f\n" - "st1 { v26.s }[2], [x24]\n" - "st1 { v25.s }[2], [x9]\n" - "st1 { v20.s }[2], [x25]\n" - "st1 { v19.s }[2], [x11]\n" - "st1 { v14.s }[2], [x26]\n" - "st1 { v13.s }[2], [x10]\n" - "st1 { v8.s }[2], [x27]\n" - "st1 { v7.s }[2], [%x[Cpanel]]\n" - "b 13f\n" - "12:" // partial result store: partial_1_0 - "str s26, [x24, #0x0]\n" - "str s25, [x9, #0x0]\n" - "str s20, [x25, #0x0]\n" - "str s19, [x11, #0x0]\n" - "str s14, [x26, #0x0]\n" - "str 
s13, [x10, #0x0]\n" - "str s8, [x27, #0x0]\n" - "str s7, [%x[Cpanel], #0x0]\n" - "13:" // partial result store: Done - "14:" // store done - "subs x23, x23, #0xc\n" - "bgt 3b\n" - "subs %x[M], %x[M], #0x8\n" - "mov %x[Cpanel], x28\n" - "bgt 1b\n" - : [Apanel] "+&r" (Apanel), [Cpanel] "+&r" (Cpanel), [M] "+&r" (M) - : [args_ptr] "r" (&ka), [ldc] "r" (ldc * sizeof(float)), [offset_max] "I" (offsetof(KernelArgs, maxval)), [offset_min] "I" (offsetof(KernelArgs, minval)), [offsetof_Bpanel] "I" (offsetof(KernelArgs, Bpanel)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_N] "I" (offsetof(KernelArgs, N)) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" - ); -} - -#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/matmul_clamp_bf16_bf16_f32p_interface.h b/kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/matmul_clamp_bf16_bf16_f32p_interface.h deleted file mode 100644 index 28a715b4..00000000 --- a/kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/matmul_clamp_bf16_bf16_f32p_interface.h +++ /dev/null @@ -1,57 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// -#pragma once - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) -#error This file must be compiled for AArch64, FEAT_BF16. -#else // Architectural features check. 
- -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// All micro-kernels variants of the same type share the same interfaces -// In this case, the micro-kernel type is: matmul_clamp_bf16_bf16_f32p - -/// Micro-kernel helper functions ("get" methods) -typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_m_step_func_t)(void); -typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_n_step_func_t)(void); -typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_mr_func_t)(void); -typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_nr_func_t)(void); -typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_kr_func_t)(void); -typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_sr_func_t)(void); -typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_lhs_offset_func_t)(size_t m_idx, size_t lhs_stride); -typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); -typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride); -typedef size_t (*kai_matmul_clamp_bf16_bf16_f32p_get_dst_size_func_t)(size_t m, size_t n); - -/// Micro-kernel core function ("run" method) -typedef void (*kai_matmul_clamp_bf16_bf16_f32p_run_matmul_func_t)( - size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* rhs_packed, void* dst, - size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max); - -/// Micro-kernel interface -struct kai_matmul_clamp_bf16_bf16_f32p_ukernel { - kai_matmul_clamp_bf16_bf16_f32p_get_m_step_func_t get_m_step; - kai_matmul_clamp_bf16_bf16_f32p_get_n_step_func_t get_n_step; - kai_matmul_clamp_bf16_bf16_f32p_get_mr_func_t get_mr; - kai_matmul_clamp_bf16_bf16_f32p_get_nr_func_t get_nr; - kai_matmul_clamp_bf16_bf16_f32p_get_nr_func_t get_kr; - kai_matmul_clamp_bf16_bf16_f32p_get_sr_func_t get_sr; - kai_matmul_clamp_bf16_bf16_f32p_get_lhs_offset_func_t get_lhs_packed_offset; - 
kai_matmul_clamp_bf16_bf16_f32p_get_rhs_packed_offset_func_t get_rhs_packed_offset; - kai_matmul_clamp_bf16_bf16_f32p_get_dst_offset_func_t get_dst_offset; - kai_matmul_clamp_bf16_bf16_f32p_get_dst_size_func_t get_dst_size; - kai_matmul_clamp_bf16_bf16_f32p_run_matmul_func_t run_matmul; -}; - -#ifdef __cplusplus -} -#endif - -#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c new file mode 100644 index 00000000..13ebc4e3 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c @@ -0,0 +1,583 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_BF16. +#else // Architectural features check. 
+ +#include +#include +#include +#include +#include + +typedef bfloat16_t bfloat16; + +#include "kai/kai_common.h" +#include "kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h" + +static const size_t kai_mr = 8; +static const size_t kai_nr = 12; +static const size_t kai_kr = 4; +static const size_t kai_sr = 1; + +size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) { + return kai_mr; +} + +size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) { + return kai_nr; +} + +size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) { + return kai_sr; +} + +size_t kai_get_lhs_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t m_idx, size_t stride) { + KAI_ASSUME(m_idx % kai_mr == 0); + + return m_idx * stride; +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_nr == 0); + + // return n_idx / kai_nr * (kai_nr * sizeof(float) + kai_nr * k * sizeof(bfloat16)); + return n_idx * (sizeof(float) + kai_roundup(k, kai_kr) * sizeof(bfloat16_t)); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla( + size_t m_idx, size_t n_idx, size_t stride) { + KAI_ASSUME(m_idx % kai_mr == 0); + KAI_ASSUME(n_idx % kai_nr == 0); + + return m_idx * stride + n_idx * sizeof(float); +} + +size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla( + size_t m, size_t n, size_t k, // + const void* 
lhs_packed, // + const void* rhs_packed, // + void* dst, size_t dst_stride_row, size_t dst_stride_col, // + float clamp_min, float clamp_max) { + KAI_ASSERT(dst_stride_col == sizeof(float)); + + const void* Apanel = lhs_packed; + // const void *Bpanel = rhs_packed; + void* Cpanel = dst; + size_t ldc = dst_stride_row / sizeof(float); + + size_t M = m; + + typedef struct { + float maxval; + float minval; + size_t N; + size_t K; + const void* Bpanel; + void* output_ptr; + } KernelArgs; + + KernelArgs ka; + + ka.N = n; + ka.K = kai_roundup(k, 4) / 4 - 1; + + ka.Bpanel = rhs_packed; + + // Direct output. + ka.output_ptr = dst; + + // Clamping output. + ka.maxval = clamp_max; + ka.minval = clamp_min; + + __asm__ __volatile__( + "1:" // Height loop + "add x11, %x[Cpanel], %x[ldc], LSL #2\n" + "add x10, %x[Cpanel], %x[ldc], LSL #1\n" + "add x9, x11, %x[ldc], LSL #1\n" + "cmp %x[M], #0x8\n" + "add x28, %x[Cpanel], %x[ldc], LSL #3\n" + "add x27, %x[Cpanel], %x[ldc]\n" + "add x26, x10, %x[ldc]\n" + "add x25, x11, %x[ldc]\n" + "add x24, x9, %x[ldc]\n" + "bge 2f\n" + "cmp %x[M], #0x2\n" + "mov x24, %x[Cpanel]\n" + "csel x27, x27, %x[Cpanel], GE\n" + "csel x10, x10, %x[Cpanel], GT\n" + "cmp %x[M], #0x4\n" + "csel x26, x26, %x[Cpanel], GE\n" + "csel x11, x11, %x[Cpanel], GT\n" + "cmp %x[M], #0x6\n" + "csel x25, x25, %x[Cpanel], GE\n" + "csel x9, x9, %x[Cpanel], GT\n" + "2:" // all rows valid + "ldr x23, [%x[args_ptr], %[offsetof_N]]\n" + "ldr x22, [%x[args_ptr], %[offsetof_Bpanel]]\n" + "mov x21, %x[Apanel]\n" + "3:" // Width loop + "ldr q4, [x22, #0x0]\n" + "ldr q5, [x22, #0x10]\n" + "mov %x[Apanel], x21\n" + "ldr q6, [x22, #0x20]\n" + "ldr x20, [%x[args_ptr], %[offsetof_K]]\n" + "add x22, x22, #0x30\n" + "ldr q7, [x22, #0x0]\n" + "ldr q0, [%x[Apanel], #0x0]\n" + "ldr q1, [%x[Apanel], #0x10]\n" + "zip1 v8.2d, v4.2d, v4.2d\n" + "ldr q2, [%x[Apanel], #0x20]\n" + "zip2 v11.2d, v4.2d, v4.2d\n" + "ldr q4, [x22, #0x10]\n" + "zip1 v9.2d, v5.2d, v5.2d\n" + "zip2 v12.2d, v5.2d, v5.2d\n" 
+ "cmp x20, #0x2\n" + "zip1 v10.2d, v6.2d, v6.2d\n" + "zip2 v13.2d, v6.2d, v6.2d\n" + "prfm pldl1keep, [%x[Apanel], #0x0]\n" + "mov v14.16b, v8.16b\n" + "mov v17.16b, v11.16b\n" + "prfm pldl1keep, [x22, #0x0]\n" + "mov v15.16b, v9.16b\n" + "mov v18.16b, v12.16b\n" + "prfm pldl1keep, [x22, #0x40]\n" + "mov v16.16b, v10.16b\n" + "mov v19.16b, v13.16b\n" + "prfm pldl1keep, [%x[Apanel], #0x40]\n" + "mov v20.16b, v8.16b\n" + "mov v21.16b, v9.16b\n" + "prfm pldl1keep, [x22, #0x80]\n" + "mov v22.16b, v10.16b\n" + "mov v23.16b, v11.16b\n" + "prfm pldl1keep, [%x[Apanel], #0x80]\n" + "mov v24.16b, v12.16b\n" + "mov v25.16b, v13.16b\n" + "prfm pldl1keep, [x22, #0xc0]\n" + "mov v26.16b, v8.16b\n" + "mov v27.16b, v9.16b\n" + "prfm pldl1keep, [x22, #0x100]\n" + "mov v28.16b, v10.16b\n" + "mov v29.16b, v11.16b\n" + "prfm pldl1keep, [%x[Apanel], #0xc0]\n" + "mov v30.16b, v12.16b\n" + "mov v31.16b, v13.16b\n" + "prfm pldl1keep, [x22, #0x140]\n" + "add x22, x22, #0x20\n" + "add %x[Apanel], %x[Apanel], #0x30\n" + "blt 5f\n" + "4:" // main loop head + "ldr q3, [%x[Apanel], #0x0]\n" + "ldr q5, [x22, #0x0]\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "ldr q6, [x22, #0x10]\n" + ".inst 0x6e44ec0b // bfmmla v11.4s, v0.8h, v4.8h\n" + ".inst 0x6e47ec2e // bfmmla v14.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec31 // bfmmla v17.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + "sub x20, x20, #0x2\n" + ".inst 0x6e44ec57 // bfmmla v23.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7a // bfmmla v26.4s, v3.8h, v7.8h\n" + "ldr q7, [x22, #0x20]\n" + ".inst 0x6e44ec7d // bfmmla v29.4s, v3.8h, v4.8h\n" + "ldr q4, [x22, #0x30]\n" + ".inst 0x6e45ec09 // bfmmla v9.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec2f // bfmmla v15.4s, v1.8h, v5.8h\n" + "cmp x20, #0x2\n" + ".inst 0x6e46ec32 // bfmmla v18.4s, v1.8h, v6.8h\n" + ".inst 0x6e45ec55 // bfmmla v21.4s, v2.8h, v5.8h\n" + "prfm pldl1keep, [%x[Apanel], #0x100]\n" + ".inst 0x6e46ec58 // 
bfmmla v24.4s, v2.8h, v6.8h\n" + ".inst 0x6e45ec7b // bfmmla v27.4s, v3.8h, v5.8h\n" + "ldr q5, [x22, #0x40]\n" + ".inst 0x6e46ec7e // bfmmla v30.4s, v3.8h, v6.8h\n" + "ldr q6, [x22, #0x50]\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + ".inst 0x6e44ec0d // bfmmla v13.4s, v0.8h, v4.8h\n" + "ldr q0, [%x[Apanel], #0x10]\n" + ".inst 0x6e47ec30 // bfmmla v16.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec33 // bfmmla v19.4s, v1.8h, v4.8h\n" + "ldr q1, [%x[Apanel], #0x20]\n" + ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n" + ".inst 0x6e44ec59 // bfmmla v25.4s, v2.8h, v4.8h\n" + "ldr q2, [%x[Apanel], #0x30]\n" + ".inst 0x6e47ec7c // bfmmla v28.4s, v3.8h, v7.8h\n" + "ldr q7, [x22, #0x60]\n" + ".inst 0x6e44ec7f // bfmmla v31.4s, v3.8h, v4.8h\n" + "ldr q3, [%x[Apanel], #0x40]\n" + "ldr q4, [x22, #0x70]\n" + ".inst 0x6e45ec08 // bfmmla v8.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec2e // bfmmla v14.4s, v1.8h, v5.8h\n" + ".inst 0x6e46ec31 // bfmmla v17.4s, v1.8h, v6.8h\n" + "prfm pldl1keep, [x22, #0x180]\n" + ".inst 0x6e45ec54 // bfmmla v20.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "prfm pldl1keep, [x22, #0x1c0]\n" + ".inst 0x6e45ec7a // bfmmla v26.4s, v3.8h, v5.8h\n" + "ldr q5, [x22, #0x80]\n" + ".inst 0x6e46ec7d // bfmmla v29.4s, v3.8h, v6.8h\n" + "ldr q6, [x22, #0x90]\n" + "prfm pldl1keep, [%x[Apanel], #0x140]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + "prfm pldl1keep, [x22, #0x200]\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e47ec2f // bfmmla v15.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec32 // bfmmla v18.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7b // bfmmla v27.4s, v3.8h, v7.8h\n" + "ldr q7, [x22, #0xa0]\n" + ".inst 0x6e44ec7e // bfmmla v30.4s, v3.8h, v4.8h\n" + "ldr q4, [x22, #0xb0]\n" + ".inst 0x6e45ec0a // bfmmla v10.4s, v0.8h, v5.8h\n" + ".inst 
0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + "ldr q0, [%x[Apanel], #0x50]\n" + ".inst 0x6e45ec30 // bfmmla v16.4s, v1.8h, v5.8h\n" + ".inst 0x6e46ec33 // bfmmla v19.4s, v1.8h, v6.8h\n" + "ldr q1, [%x[Apanel], #0x60]\n" + ".inst 0x6e45ec56 // bfmmla v22.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" + "ldr q2, [%x[Apanel], #0x70]\n" + ".inst 0x6e45ec7c // bfmmla v28.4s, v3.8h, v5.8h\n" + ".inst 0x6e46ec7f // bfmmla v31.4s, v3.8h, v6.8h\n" + "add %x[Apanel], %x[Apanel], #0x80\n" + "add x22, x22, #0xc0\n" + "bge 4b\n" + "5:" // main loop skip + "ldr q3, [%x[Apanel], #0x0]\n" + "ldr q5, [x22, #0x0]\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "ldr q6, [x22, #0x10]\n" + ".inst 0x6e44ec0b // bfmmla v11.4s, v0.8h, v4.8h\n" + ".inst 0x6e47ec2e // bfmmla v14.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec31 // bfmmla v17.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + "add %x[Apanel], %x[Apanel], #0x10\n" + ".inst 0x6e44ec57 // bfmmla v23.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7a // bfmmla v26.4s, v3.8h, v7.8h\n" + "ldr q7, [x22, #0x20]\n" + ".inst 0x6e44ec7d // bfmmla v29.4s, v3.8h, v4.8h\n" + "ldr q4, [x22, #0x30]\n" + ".inst 0x6e45ec09 // bfmmla v9.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec2f // bfmmla v15.4s, v1.8h, v5.8h\n" + "add x22, x22, #0x40\n" + ".inst 0x6e46ec32 // bfmmla v18.4s, v1.8h, v6.8h\n" + ".inst 0x6e45ec55 // bfmmla v21.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec58 // bfmmla v24.4s, v2.8h, v6.8h\n" + ".inst 0x6e45ec7b // bfmmla v27.4s, v3.8h, v5.8h\n" + ".inst 0x6e46ec7e // bfmmla v30.4s, v3.8h, v6.8h\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + ".inst 0x6e44ec0d // bfmmla v13.4s, v0.8h, v4.8h\n" + ".inst 0x6e47ec30 // bfmmla v16.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec33 // bfmmla v19.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n" + ".inst 0x6e44ec59 // bfmmla v25.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7c // bfmmla 
v28.4s, v3.8h, v7.8h\n" + ".inst 0x6e44ec7f // bfmmla v31.4s, v3.8h, v4.8h\n" + "cbz x20, 6f\n" + "ldr q5, [x22, #0x0]\n" + "ldr q0, [%x[Apanel], #0x0]\n" + "ldr q1, [%x[Apanel], #0x10]\n" + "ldr q6, [x22, #0x10]\n" + "ldr q2, [%x[Apanel], #0x20]\n" + "ldr q3, [%x[Apanel], #0x30]\n" + "add %x[Apanel], %x[Apanel], #0x40\n" + "ldr q7, [x22, #0x20]\n" + "ldr q4, [x22, #0x30]\n" + ".inst 0x6e45ec08 // bfmmla v8.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec2e // bfmmla v14.4s, v1.8h, v5.8h\n" + ".inst 0x6e46ec31 // bfmmla v17.4s, v1.8h, v6.8h\n" + ".inst 0x6e45ec54 // bfmmla v20.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + ".inst 0x6e45ec7a // bfmmla v26.4s, v3.8h, v5.8h\n" + "ldr q5, [x22, #0x40]\n" + ".inst 0x6e46ec7d // bfmmla v29.4s, v3.8h, v6.8h\n" + "ldr q6, [x22, #0x50]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e47ec2f // bfmmla v15.4s, v1.8h, v7.8h\n" + "add x22, x22, #0x60\n" + ".inst 0x6e44ec32 // bfmmla v18.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7b // bfmmla v27.4s, v3.8h, v7.8h\n" + ".inst 0x6e44ec7e // bfmmla v30.4s, v3.8h, v4.8h\n" + ".inst 0x6e45ec0a // bfmmla v10.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec30 // bfmmla v16.4s, v1.8h, v5.8h\n" + ".inst 0x6e46ec33 // bfmmla v19.4s, v1.8h, v6.8h\n" + ".inst 0x6e45ec56 // bfmmla v22.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" + ".inst 0x6e45ec7c // bfmmla v28.4s, v3.8h, v5.8h\n" + ".inst 0x6e46ec7f // bfmmla v31.4s, v3.8h, v6.8h\n" + "6:" // multiply loop done + "add x20, %x[args_ptr], %[offset_max]\n" + "uzp1 v7.2d, v8.2d, v11.2d\n" + "uzp2 v8.2d, v8.2d, v11.2d\n" + "ld1r { v1.4s }, [x20]\n" + "uzp1 v11.2d, v9.2d, v12.2d\n" + "uzp2 v9.2d, v9.2d, v12.2d\n" + "uzp1 v12.2d, 
v10.2d, v13.2d\n" + "uzp2 v10.2d, v10.2d, v13.2d\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x20]\n" + "uzp1 v13.2d, v14.2d, v17.2d\n" + "uzp2 v14.2d, v14.2d, v17.2d\n" + "uzp1 v17.2d, v15.2d, v18.2d\n" + "uzp2 v15.2d, v15.2d, v18.2d\n" + "cmp x23, #0xc\n" + "uzp1 v18.2d, v16.2d, v19.2d\n" + "uzp2 v16.2d, v16.2d, v19.2d\n" + "uzp1 v19.2d, v20.2d, v23.2d\n" + "uzp2 v20.2d, v20.2d, v23.2d\n" + "uzp1 v23.2d, v21.2d, v24.2d\n" + "uzp2 v21.2d, v21.2d, v24.2d\n" + "uzp1 v24.2d, v22.2d, v25.2d\n" + "uzp2 v22.2d, v22.2d, v25.2d\n" + "uzp1 v25.2d, v26.2d, v29.2d\n" + "uzp2 v26.2d, v26.2d, v29.2d\n" + "uzp1 v29.2d, v27.2d, v30.2d\n" + "uzp2 v27.2d, v27.2d, v30.2d\n" + "uzp1 v30.2d, v28.2d, v31.2d\n" + "uzp2 v28.2d, v28.2d, v31.2d\n" + "fmin v7.4s, v7.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v19.4s, v19.4s, v1.4s\n" + "fmin v23.4s, v23.4s, v1.4s\n" + "fmin v24.4s, v24.4s, v1.4s\n" + "fmin v20.4s, v20.4s, v1.4s\n" + "fmin v21.4s, v21.4s, v1.4s\n" + "fmin v22.4s, v22.4s, v1.4s\n" + "fmin v25.4s, v25.4s, v1.4s\n" + "fmin v29.4s, v29.4s, v1.4s\n" + "fmin v30.4s, v30.4s, v1.4s\n" + "fmin v26.4s, v26.4s, v1.4s\n" + "fmin v27.4s, v27.4s, v1.4s\n" + "fmin v28.4s, v28.4s, v1.4s\n" + "fmax v7.4s, v7.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v23.4s, v23.4s, v0.4s\n" + 
"fmax v24.4s, v24.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "fmax v25.4s, v25.4s, v0.4s\n" + "fmax v29.4s, v29.4s, v0.4s\n" + "fmax v30.4s, v30.4s, v0.4s\n" + "fmax v26.4s, v26.4s, v0.4s\n" + "fmax v27.4s, v27.4s, v0.4s\n" + "fmax v28.4s, v28.4s, v0.4s\n" + "blt 7f\n" + "str q26, [x24, #0x0]\n" + "str q27, [x24, #0x10]\n" + "str q28, [x24, #0x20]\n" + "add x24, x24, #0x30\n" + "str q25, [x9, #0x0]\n" + "str q29, [x9, #0x10]\n" + "str q30, [x9, #0x20]\n" + "add x9, x9, #0x30\n" + "str q20, [x25, #0x0]\n" + "str q21, [x25, #0x10]\n" + "str q22, [x25, #0x20]\n" + "add x25, x25, #0x30\n" + "str q19, [x11, #0x0]\n" + "str q23, [x11, #0x10]\n" + "str q24, [x11, #0x20]\n" + "add x11, x11, #0x30\n" + "str q14, [x26, #0x0]\n" + "str q15, [x26, #0x10]\n" + "str q16, [x26, #0x20]\n" + "add x26, x26, #0x30\n" + "str q13, [x10, #0x0]\n" + "str q17, [x10, #0x10]\n" + "str q18, [x10, #0x20]\n" + "add x10, x10, #0x30\n" + "str q8, [x27, #0x0]\n" + "str q9, [x27, #0x10]\n" + "str q10, [x27, #0x20]\n" + "add x27, x27, #0x30\n" + "str q7, [%x[Cpanel], #0x0]\n" + "str q11, [%x[Cpanel], #0x10]\n" + "str q12, [%x[Cpanel], #0x20]\n" + "add %x[Cpanel], %x[Cpanel], #0x30\n" + "b 14f\n" + "7:" // partial output + "tbz x23, #3, 9f\n" + "st1 { v26.4s }, [x24], #0x10\n" + "st1 { v27.4s }, [x24], #0x10\n" + "st1 { v25.4s }, [x9], #0x10\n" + "st1 { v29.4s }, [x9], #0x10\n" + "st1 { v20.4s }, [x25], #0x10\n" + "st1 { v21.4s }, [x25], #0x10\n" + "st1 { v19.4s }, [x11], #0x10\n" + "st1 { v23.4s }, [x11], #0x10\n" + "st1 { v14.4s }, [x26], #0x10\n" + "st1 { v15.4s }, [x26], #0x10\n" + "st1 { v13.4s }, [x10], #0x10\n" + "st1 { v17.4s }, [x10], #0x10\n" + "st1 { v8.4s }, [x27], #0x10\n" + "st1 { v9.4s }, [x27], #0x10\n" + "st1 { v7.4s }, [%x[Cpanel]], #0x10\n" + "st1 { v11.4s }, [%x[Cpanel]], #0x10\n" + "tbz x23, #1, 8f\n" + "str d28, [x24], #0x8\n" + "str d30, [x9], #0x8\n" + "str d22, [x25], #0x8\n" + "str d24, [x11], #0x8\n" + 
"str d16, [x26], #0x8\n" + "str d18, [x10], #0x8\n" + "str d10, [x27], #0x8\n" + "str d12, [%x[Cpanel]], #0x8\n" + "tbz x23, #0, 13f\n" + "st1 { v28.s }[2], [x24]\n" + "st1 { v30.s }[2], [x9]\n" + "st1 { v22.s }[2], [x25]\n" + "st1 { v24.s }[2], [x11]\n" + "st1 { v16.s }[2], [x26]\n" + "st1 { v18.s }[2], [x10]\n" + "st1 { v10.s }[2], [x27]\n" + "st1 { v12.s }[2], [%x[Cpanel]]\n" + "b 13f\n" + "8:" // partial result store: partial_1_8 + "tbz x23, #0, 13f\n" + "str s28, [x24, #0x0]\n" + "str s30, [x9, #0x0]\n" + "str s22, [x25, #0x0]\n" + "str s24, [x11, #0x0]\n" + "str s16, [x26, #0x0]\n" + "str s18, [x10, #0x0]\n" + "str s10, [x27, #0x0]\n" + "str s12, [%x[Cpanel], #0x0]\n" + "b 13f\n" + "9:" // partial result store: partial_4_0 + "tbz x23, #2, 11f\n" + "st1 { v26.4s }, [x24], #0x10\n" + "st1 { v25.4s }, [x9], #0x10\n" + "st1 { v20.4s }, [x25], #0x10\n" + "st1 { v19.4s }, [x11], #0x10\n" + "st1 { v14.4s }, [x26], #0x10\n" + "st1 { v13.4s }, [x10], #0x10\n" + "st1 { v8.4s }, [x27], #0x10\n" + "st1 { v7.4s }, [%x[Cpanel]], #0x10\n" + "tbz x23, #1, 10f\n" + "str d27, [x24], #0x8\n" + "str d29, [x9], #0x8\n" + "str d21, [x25], #0x8\n" + "str d23, [x11], #0x8\n" + "str d15, [x26], #0x8\n" + "str d17, [x10], #0x8\n" + "str d9, [x27], #0x8\n" + "str d11, [%x[Cpanel]], #0x8\n" + "tbz x23, #0, 13f\n" + "st1 { v27.s }[2], [x24]\n" + "st1 { v29.s }[2], [x9]\n" + "st1 { v21.s }[2], [x25]\n" + "st1 { v23.s }[2], [x11]\n" + "st1 { v15.s }[2], [x26]\n" + "st1 { v17.s }[2], [x10]\n" + "st1 { v9.s }[2], [x27]\n" + "st1 { v11.s }[2], [%x[Cpanel]]\n" + "b 13f\n" + "10:" // partial result store: partial_1_4 + "tbz x23, #0, 13f\n" + "str s27, [x24, #0x0]\n" + "str s29, [x9, #0x0]\n" + "str s21, [x25, #0x0]\n" + "str s23, [x11, #0x0]\n" + "str s15, [x26, #0x0]\n" + "str s17, [x10, #0x0]\n" + "str s9, [x27, #0x0]\n" + "str s11, [%x[Cpanel], #0x0]\n" + "b 13f\n" + "11:" // partial result store: partial_2_0 + "tbz x23, #1, 12f\n" + "str d26, [x24], #0x8\n" + "str d25, [x9], #0x8\n" + "str 
d20, [x25], #0x8\n" + "str d19, [x11], #0x8\n" + "str d14, [x26], #0x8\n" + "str d13, [x10], #0x8\n" + "str d8, [x27], #0x8\n" + "str d7, [%x[Cpanel]], #0x8\n" + "tbz x23, #0, 13f\n" + "st1 { v26.s }[2], [x24]\n" + "st1 { v25.s }[2], [x9]\n" + "st1 { v20.s }[2], [x25]\n" + "st1 { v19.s }[2], [x11]\n" + "st1 { v14.s }[2], [x26]\n" + "st1 { v13.s }[2], [x10]\n" + "st1 { v8.s }[2], [x27]\n" + "st1 { v7.s }[2], [%x[Cpanel]]\n" + "b 13f\n" + "12:" // partial result store: partial_1_0 + "str s26, [x24, #0x0]\n" + "str s25, [x9, #0x0]\n" + "str s20, [x25, #0x0]\n" + "str s19, [x11, #0x0]\n" + "str s14, [x26, #0x0]\n" + "str s13, [x10, #0x0]\n" + "str s8, [x27, #0x0]\n" + "str s7, [%x[Cpanel], #0x0]\n" + "13:" // partial result store: Done + "14:" // store done + "subs x23, x23, #0xc\n" + "bgt 3b\n" + "subs %x[M], %x[M], #0x8\n" + "mov %x[Cpanel], x28\n" + "bgt 1b\n" + : [Apanel] "+&r"(Apanel), [Cpanel] "+&r"(Cpanel), [M] "+&r"(M) + : [args_ptr] "r"(&ka), [ldc] "r"(ldc * sizeof(float)), [offset_max] "I"(offsetof(KernelArgs, maxval)), + [offset_min] "I"(offsetof(KernelArgs, minval)), [offsetof_Bpanel] "I"(offsetof(KernelArgs, Bpanel)), + [offsetof_K] "I"(offsetof(KernelArgs, K)), [offsetof_N] "I"(offsetof(KernelArgs, N)) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", "x9", "x10", "x11", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); +} + +#endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h similarity index 64% rename from kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.h rename to kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h index 56425303..000d690a 100644 --- a/kai/ukernels/matmul/matmul_clamp_bf16_bf16_f32p/kai_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h @@ -23,37 +23,42 @@ extern "C" { /// The starting row index must be divisible by `m_step`. /// /// @return The m step value. -size_t kai_get_m_step_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void); +size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void); /// Gets n step value. /// /// The starting column index must be divisible by `n_step`. /// /// @return The n step value. -size_t kai_get_n_step_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void); +size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void); -size_t kai_get_mr_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void); +/// Gets mr value. +/// +/// This is the packing parameter which must be used to pack the LHS matrix. +/// +/// @return The mr value. +size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void); /// Gets nr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The nr value. -size_t kai_get_nr_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void); +size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void); /// Gets kr value. 
/// -/// This is the packing parameter which must be used to pack the RHS matrix. +/// This is the packing parameter which must be used to pack the LHS & RHS matrices. /// /// @return The kr value. -size_t kai_get_kr_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void); +size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void); /// Gets sr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The sr value. -size_t kai_get_sr_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void); +size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void); /// Gets the offset in bytes to the data element in the LHS matrix buffer. /// @@ -61,7 +66,7 @@ size_t kai_get_sr_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(void); /// @param[in] stride Row stride in bytes. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(size_t m_idx, size_t stride); +size_t kai_get_lhs_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t m_idx, size_t stride); /// Gets the offset in bytes to the data element in the packed RHS matrix buffer. /// @@ -69,7 +74,7 @@ size_t kai_get_lhs_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mml /// @param[in] k Number of columns. /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_packed_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t n_idx, size_t k); /// Gets the offset in bytes to the data element in the destination matrix buffer. /// @@ -78,7 +83,8 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_n /// @param[in] stride Row stride in bytes. /// /// @return The offset in bytes to the data element. 
-size_t kai_get_dst_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(size_t m_idx, size_t n_idx, size_t stride); +size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla( + size_t m_idx, size_t n_idx, size_t stride); /// Gets the size in bytes of the destination matrix buffer. /// @@ -86,33 +92,32 @@ size_t kai_get_dst_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mml /// @param[in] n Number of columns. /// /// @return The size in bytes of the destination matrix buffer. -size_t kai_get_dst_size_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla(size_t m, size_t n); +size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t m, size_t n); /// Runs the matrix multiplication microkernel followed by a clamp operation. /// /// The pointer of each buffers (LHS, packed RHS and output) needs to be added with offset /// calculated using the following functions: /// -/// * LHS: @ref kai_get_lhs_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla. -/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla. -/// * Output: @ref kai_get_dst_offset_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla. +/// * Packed LHS: @ref kai_get_lhs_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla. /// /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. /// @param[in] k Common dimension of the LHS and RHS operand. /// @param[in] lhs LHS matrix buffer. -/// @param[in] lhs_stride Row stride in bytes of the LHS matrix. /// @param[in] rhs_packed Packed RHS buffer. /// @param[out] dst Output matrix buffer. 
/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) /// @param[in] clamp_min Minimum value to clamp the final result. /// @param[in] clamp_max Maximum value to clamp the final result. -void kai_run_matmul_clamp_bf16_bf16_f32p12x1biasf32_8x12x4_neon_mmla( - size_t m, size_t n, size_t k, // - const void* lhs_packed, size_t lhs_stride, // - const void* rhs_packed, // - void* dst, size_t dst_stride_row, size_t dst_stride_col, // +void kai_run_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla( + size_t m, size_t n, size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, size_t dst_stride_row, size_t dst_stride_col, // float clamp_min, float clamp_max); #ifdef __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p_interface.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p_interface.h new file mode 100644 index 00000000..8eb1969a --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p_interface.h @@ -0,0 +1,57 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_BF16. +#else // Architectural features check. 
+ +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// All micro-kernels variants of the same type share the same interfaces +// In this case, the micro-kernel type is: matmul_clamp_f32_bf16p_bf16p + +/// Micro-kernel helper functions ("get" methods) +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_m_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_n_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_mr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_nr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_kr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_sr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_offset_func_t)(size_t m_idx, size_t lhs_stride); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_size_func_t)(size_t m, size_t n); + +/// Micro-kernel core function ("run" method) +typedef void (*kai_matmul_clamp_f32_bf16p_bf16p_run_matmul_func_t)( + size_t m, size_t n, size_t k, const void* lhs, const void* rhs_packed, void* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max); + +/// Micro-kernel interface +struct kai_matmul_clamp_f32_bf16p_bf16p_ukernel { + kai_matmul_clamp_f32_bf16p_bf16p_get_m_step_func_t get_m_step; + kai_matmul_clamp_f32_bf16p_bf16p_get_n_step_func_t get_n_step; + kai_matmul_clamp_f32_bf16p_bf16p_get_mr_func_t get_mr; + kai_matmul_clamp_f32_bf16p_bf16p_get_nr_func_t get_nr; + kai_matmul_clamp_f32_bf16p_bf16p_get_nr_func_t get_kr; + kai_matmul_clamp_f32_bf16p_bf16p_get_sr_func_t get_sr; + kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_offset_func_t get_lhs_packed_offset; + 
kai_matmul_clamp_f32_bf16p_bf16p_get_rhs_packed_offset_func_t get_rhs_packed_offset; + kai_matmul_clamp_f32_bf16p_bf16p_get_dst_offset_func_t get_dst_offset; + kai_matmul_clamp_f32_bf16p_bf16p_get_dst_size_func_t get_dst_size; + kai_matmul_clamp_f32_bf16p_bf16p_run_matmul_func_t run_matmul; +}; + +#ifdef __cplusplus +} +#endif + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_8x4_f32_bf16_neon.c b/kai/ukernels/matmul/pack/kai_lhs_pack_8x4_f32_bf16_neon.c deleted file mode 100644 index af8557d0..00000000 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_8x4_f32_bf16_neon.c +++ /dev/null @@ -1,222 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) -#error This file must be compiled for AArch64, FEAT_BF16. -#else // Architectural features check. - -#include -#include -#include -#include "kai/kai_common.h" - -static const size_t kai_mr = 8; -static const size_t kai_kr = 4; -static const size_t kai_sr = 1; -static const size_t vec_len = 1; - -size_t kai_get_m_step_lhs_pack_8x4_f32_bf16_neon(size_t mr) { - KAI_ASSUME(mr == kai_mr * vec_len); - KAI_UNUSED(mr); - - return kai_mr * vec_len; -} - -size_t kai_get_lhs_offset_lhs_pack_8x4_f32_bf16_neon(size_t m_idx, size_t lhs_stride) { - KAI_ASSUME(m_idx % (kai_mr * vec_len) == 0); - - return m_idx * lhs_stride; -} - -size_t kai_get_lhs_packed_offset_lhs_pack_8x4_f32_bf16_neon(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { - const size_t scaled_mr = kai_mr * vec_len; - KAI_ASSUME(m_idx % scaled_mr == 0); - KAI_ASSUME(mr == scaled_mr); - KAI_ASSUME(kr == kai_kr); - KAI_ASSUME(sr == kai_sr); - - KAI_UNUSED(mr); - KAI_UNUSED(kr); - KAI_UNUSED(sr); - - return m_idx * kai_roundup(k, kr) * sizeof(bfloat16_t); -} - -size_t kai_get_lhs_packed_size_lhs_pack_8x4_f32_bf16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t 
sr) { - KAI_ASSUME(mr == kai_mr * vec_len); - KAI_ASSUME(kr == kai_kr); - KAI_ASSUME(sr == kai_sr); - - KAI_UNUSED(mr); - KAI_UNUSED(kr); - KAI_UNUSED(sr); - - return kai_roundup(m, kai_mr * vec_len) * kai_roundup(k, kai_kr) * sizeof(bfloat16_t); -} - -void kai_run_lhs_pack_8x4_f32_bf16_neon( - size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, - const void* lhs, size_t lhs_stride, void* lhs_packed -) -{ - KAI_ASSUME(mr == kai_mr * vec_len); - KAI_ASSUME(kr == kai_kr); - KAI_ASSUME(sr == kai_sr); - KAI_ASSUME(lhs != NULL); - KAI_ASSUME(lhs_packed != NULL); - - KAI_ASSUME(m_idx_start == 0); - - const size_t block_height = kai_mr * vec_len; - const size_t row_offset = 0; - - const void* in[block_height]; - - for (size_t block_y = 0; block_y < m; block_y += block_height) { - const size_t height = KAI_MIN(m - block_y, block_height); - void* out = (char*)lhs_packed + block_y * kai_roundup(k,kr) * sizeof(bfloat16_t); - size_t width = k; - - for (size_t y = 0; y < height; y++) { - in[y] = (char*)lhs + (block_y + y) * lhs_stride; - } - - __asm__ __volatile__( - "ldr x28, [%x[in], #0x0]\n" - "ldr x27, [%x[in], #0x8]\n" - "cmp %x[height], #0x8\n" - "ldr x26, [%x[in], #0x10]\n" - "ldr x25, [%x[in], #0x18]\n" - "ldr x24, [%x[in], #0x20]\n" - "ldr x23, [%x[in], #0x28]\n" - "ldr x22, [%x[in], #0x30]\n" - "ldr x21, [%x[in], #0x38]\n" - "add x28, x28, %x[row_offset], LSL #2\n" - "add x27, x27, %x[row_offset], LSL #2\n" - "add x26, x26, %x[row_offset], LSL #2\n" - "add x25, x25, %x[row_offset], LSL #2\n" - "add x24, x24, %x[row_offset], LSL #2\n" - "add x23, x23, %x[row_offset], LSL #2\n" - "add x22, x22, %x[row_offset], LSL #2\n" - "add x21, x21, %x[row_offset], LSL #2\n" - "beq 1f\n" - "cmp %x[height], #0x2\n" - "mov x21, x28\n" - "csel x27, x27, x28, GE\n" - "csel x26, x26, x28, GT\n" - "cmp %x[height], #0x4\n" - "csel x25, x25, x28, GE\n" - "csel x24, x24, x28, GT\n" - "cmp %x[height], #0x6\n" - "csel x23, x23, x28, GE\n" - "csel x22, x22, x28, GT\n" - 
"1:" // no_pointer_adj - "cmp %x[width], #0x4\n" - "prfm pldl1keep, [x28, #0x0]\n" - "prfm pldl1keep, [x27, #0x0]\n" - "prfm pldl1keep, [x26, #0x0]\n" - "prfm pldl1keep, [x25, #0x0]\n" - "prfm pldl1keep, [x24, #0x0]\n" - "prfm pldl1keep, [x23, #0x0]\n" - "prfm pldl1keep, [x22, #0x0]\n" - "prfm pldl1keep, [x21, #0x0]\n" - "prfm pldl1keep, [x28, #0x40]\n" - "prfm pldl1keep, [x27, #0x40]\n" - "prfm pldl1keep, [x26, #0x40]\n" - "prfm pldl1keep, [x25, #0x40]\n" - "prfm pldl1keep, [x24, #0x40]\n" - "prfm pldl1keep, [x23, #0x40]\n" - "prfm pldl1keep, [x22, #0x40]\n" - "prfm pldl1keep, [x21, #0x40]\n" - "blt 3f\n" - "2:" // Main loop head - "ldr q19, [x28], #0x10\n" - "ldr q18, [x26], #0x10\n" - "subs %x[width], %x[width], #0x4\n" - "ldr q17, [x24], #0x10\n" - "ldr q16, [x22], #0x10\n" - "cmp %x[width], #0x4\n" - "ldr q23, [x27], #0x10\n" - "ldr q22, [x25], #0x10\n" - "ldr q21, [x23], #0x10\n" - "ldr q20, [x21], #0x10\n" - ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n" - ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n" - ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" - ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" - "prfm pldl1keep, [x28, #0x70]\n" - "prfm pldl1keep, [x27, #0x70]\n" - "prfm pldl1keep, [x26, #0x70]\n" - "prfm pldl1keep, [x25, #0x70]\n" - "prfm pldl1keep, [x24, #0x70]\n" - "prfm pldl1keep, [x23, #0x70]\n" - ".inst 0x4ea16af3 // bfcvtn2 v19.8h, v23.4s\n" - ".inst 0x4ea16ad2 // bfcvtn2 v18.8h, v22.4s\n" - "prfm pldl1keep, [x22, #0x70]\n" - "prfm pldl1keep, [x21, #0x70]\n" - ".inst 0x4ea16ab1 // bfcvtn2 v17.8h, v21.4s\n" - ".inst 0x4ea16a90 // bfcvtn2 v16.8h, v20.4s\n" - "str q19, [%x[out_ptr], #0x0]\n" - "str q18, [%x[out_ptr], #0x10]\n" - "str q17, [%x[out_ptr], #0x20]\n" - "str q16, [%x[out_ptr], #0x30]\n" - "add %x[out_ptr], %x[out_ptr], #0x40\n" - "bge 2b\n" - "3:" // Main loop skip - "cbz %x[width], 6f\n" - "tbz %x[width], #1, 4f\n" - "ldr d19, [x28], #0x8\n" - "ldr d23, [x27], #0x8\n" - "mov x20, #0x1\n" - "ldr d18, [x26], #0x8\n" - "ldr d22, [x25], #0x8\n" - 
"ldr d17, [x24], #0x8\n" - "ldr d21, [x23], #0x8\n" - "ldr d16, [x22], #0x8\n" - "ldr d20, [x21], #0x8\n" - "tbz %x[width], #0, 5f\n" - "ld1 { v19.s }[2], [x28]\n" - "ld1 { v23.s }[2], [x27]\n" - "ld1 { v18.s }[2], [x26]\n" - "ld1 { v22.s }[2], [x25]\n" - "ld1 { v17.s }[2], [x24]\n" - "ld1 { v21.s }[2], [x23]\n" - "ld1 { v16.s }[2], [x22]\n" - "ld1 { v20.s }[2], [x21]\n" - "b 5f\n" - "4:" // odd_loads_1_0 - "ldr s19, [x28, #0x0]\n" - "ldr s23, [x27, #0x0]\n" - "mov x20, #0x1\n" - "ldr s18, [x26, #0x0]\n" - "ldr s22, [x25, #0x0]\n" - "ldr s17, [x24, #0x0]\n" - "ldr s21, [x23, #0x0]\n" - "ldr s16, [x22, #0x0]\n" - "ldr s20, [x21, #0x0]\n" - "5:" // Odd load end - ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n" - ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n" - ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" - ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" - ".inst 0x4ea16af3 // bfcvtn2 v19.8h, v23.4s\n" - ".inst 0x4ea16ad2 // bfcvtn2 v18.8h, v22.4s\n" - ".inst 0x4ea16ab1 // bfcvtn2 v17.8h, v21.4s\n" - ".inst 0x4ea16a90 // bfcvtn2 v16.8h, v20.4s\n" - "str q19, [%x[out_ptr], #0x0]\n" - "str q18, [%x[out_ptr], #0x10]\n" - "str q17, [%x[out_ptr], #0x20]\n" - "str q16, [%x[out_ptr], #0x30]\n" - "add %x[out_ptr], %x[out_ptr], #0x40\n" - "6:" // Odds skip - : [out_ptr] "+&r" (out), [width] "+&r" (width) - : [height] "r" (height), [in] "r" (in), [row_offset] "r" (row_offset) - : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" - ); - } -} - -#endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.c b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.c new file mode 100644 index 00000000..32c845f7 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.c @@ -0,0 +1,212 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_BF16. +#else // Architectural features check. + +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_mr = 8; +static const size_t kai_kr = 4; +static const size_t kai_sr = 1; + +size_t kai_get_m_step_lhs_pack_f32p8x4_bf16_neon(size_t mr) { + KAI_ASSUME(mr == kai_mr); + KAI_UNUSED(mr); + + return kai_mr; +} + +size_t kai_get_lhs_offset_lhs_pack_f32p8x4_bf16_neon(size_t m_idx, size_t lhs_stride) { + KAI_ASSUME(m_idx % (kai_mr) == 0); + + return m_idx * lhs_stride; +} + +size_t kai_get_lhs_packed_offset_lhs_pack_f32p8x4_bf16_neon(size_t m_idx, size_t k) { + KAI_ASSUME(m_idx % kai_mr == 0); + + return m_idx * kai_roundup(k, kai_kr) * sizeof(bfloat16_t); +} + +size_t kai_get_lhs_packed_size_lhs_pack_f32p8x4_bf16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { + KAI_ASSUME(mr == kai_mr); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + + KAI_UNUSED(mr); + KAI_UNUSED(kr); + KAI_UNUSED(sr); + + return kai_roundup(m, kai_mr) * kai_roundup(k, kai_kr) * sizeof(bfloat16_t); +} + +void kai_run_lhs_pack_f32p8x4_bf16_neon( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, + void* lhs_packed) { + KAI_ASSUME(mr == kai_mr); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + KAI_ASSUME(lhs != NULL); + KAI_ASSUME(lhs_packed != NULL); + + KAI_ASSUME(m_idx_start == 0); + + const size_t block_height = kai_mr; + const size_t row_offset = 0; + 
+ const void* in[block_height]; + + for (size_t block_y = 0; block_y < m; block_y += block_height) { + const size_t height = KAI_MIN(m - block_y, block_height); + void* out = (char*)lhs_packed + block_y * kai_roundup(k, kr) * sizeof(bfloat16_t); + size_t width = k; + + for (size_t y = 0; y < height; y++) { + in[y] = (char*)lhs + (block_y + y) * lhs_stride; + } + + __asm__ __volatile__( + "ldr x28, [%x[in], #0x0]\n" + "ldr x27, [%x[in], #0x8]\n" + "cmp %x[height], #0x8\n" + "ldr x26, [%x[in], #0x10]\n" + "ldr x25, [%x[in], #0x18]\n" + "ldr x24, [%x[in], #0x20]\n" + "ldr x23, [%x[in], #0x28]\n" + "ldr x22, [%x[in], #0x30]\n" + "ldr x21, [%x[in], #0x38]\n" + "add x28, x28, %x[row_offset], LSL #2\n" + "add x27, x27, %x[row_offset], LSL #2\n" + "add x26, x26, %x[row_offset], LSL #2\n" + "add x25, x25, %x[row_offset], LSL #2\n" + "add x24, x24, %x[row_offset], LSL #2\n" + "add x23, x23, %x[row_offset], LSL #2\n" + "add x22, x22, %x[row_offset], LSL #2\n" + "add x21, x21, %x[row_offset], LSL #2\n" + "beq 1f\n" + "cmp %x[height], #0x2\n" + "mov x21, x28\n" + "csel x27, x27, x28, GE\n" + "csel x26, x26, x28, GT\n" + "cmp %x[height], #0x4\n" + "csel x25, x25, x28, GE\n" + "csel x24, x24, x28, GT\n" + "cmp %x[height], #0x6\n" + "csel x23, x23, x28, GE\n" + "csel x22, x22, x28, GT\n" + "1:" // no_pointer_adj + "cmp %x[width], #0x4\n" + "prfm pldl1keep, [x28, #0x0]\n" + "prfm pldl1keep, [x27, #0x0]\n" + "prfm pldl1keep, [x26, #0x0]\n" + "prfm pldl1keep, [x25, #0x0]\n" + "prfm pldl1keep, [x24, #0x0]\n" + "prfm pldl1keep, [x23, #0x0]\n" + "prfm pldl1keep, [x22, #0x0]\n" + "prfm pldl1keep, [x21, #0x0]\n" + "prfm pldl1keep, [x28, #0x40]\n" + "prfm pldl1keep, [x27, #0x40]\n" + "prfm pldl1keep, [x26, #0x40]\n" + "prfm pldl1keep, [x25, #0x40]\n" + "prfm pldl1keep, [x24, #0x40]\n" + "prfm pldl1keep, [x23, #0x40]\n" + "prfm pldl1keep, [x22, #0x40]\n" + "prfm pldl1keep, [x21, #0x40]\n" + "blt 3f\n" + "2:" // Main loop head + "ldr q19, [x28], #0x10\n" + "ldr q18, [x26], #0x10\n" + "subs 
%x[width], %x[width], #0x4\n" + "ldr q17, [x24], #0x10\n" + "ldr q16, [x22], #0x10\n" + "cmp %x[width], #0x4\n" + "ldr q23, [x27], #0x10\n" + "ldr q22, [x25], #0x10\n" + "ldr q21, [x23], #0x10\n" + "ldr q20, [x21], #0x10\n" + ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n" + ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n" + ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "prfm pldl1keep, [x28, #0x70]\n" + "prfm pldl1keep, [x27, #0x70]\n" + "prfm pldl1keep, [x26, #0x70]\n" + "prfm pldl1keep, [x25, #0x70]\n" + "prfm pldl1keep, [x24, #0x70]\n" + "prfm pldl1keep, [x23, #0x70]\n" + ".inst 0x4ea16af3 // bfcvtn2 v19.8h, v23.4s\n" + ".inst 0x4ea16ad2 // bfcvtn2 v18.8h, v22.4s\n" + "prfm pldl1keep, [x22, #0x70]\n" + "prfm pldl1keep, [x21, #0x70]\n" + ".inst 0x4ea16ab1 // bfcvtn2 v17.8h, v21.4s\n" + ".inst 0x4ea16a90 // bfcvtn2 v16.8h, v20.4s\n" + "str q19, [%x[out_ptr], #0x0]\n" + "str q18, [%x[out_ptr], #0x10]\n" + "str q17, [%x[out_ptr], #0x20]\n" + "str q16, [%x[out_ptr], #0x30]\n" + "add %x[out_ptr], %x[out_ptr], #0x40\n" + "bge 2b\n" + "3:" // Main loop skip + "cbz %x[width], 6f\n" + "tbz %x[width], #1, 4f\n" + "ldr d19, [x28], #0x8\n" + "ldr d23, [x27], #0x8\n" + "mov x20, #0x1\n" + "ldr d18, [x26], #0x8\n" + "ldr d22, [x25], #0x8\n" + "ldr d17, [x24], #0x8\n" + "ldr d21, [x23], #0x8\n" + "ldr d16, [x22], #0x8\n" + "ldr d20, [x21], #0x8\n" + "tbz %x[width], #0, 5f\n" + "ld1 { v19.s }[2], [x28]\n" + "ld1 { v23.s }[2], [x27]\n" + "ld1 { v18.s }[2], [x26]\n" + "ld1 { v22.s }[2], [x25]\n" + "ld1 { v17.s }[2], [x24]\n" + "ld1 { v21.s }[2], [x23]\n" + "ld1 { v16.s }[2], [x22]\n" + "ld1 { v20.s }[2], [x21]\n" + "b 5f\n" + "4:" // odd_loads_1_0 + "ldr s19, [x28, #0x0]\n" + "ldr s23, [x27, #0x0]\n" + "mov x20, #0x1\n" + "ldr s18, [x26, #0x0]\n" + "ldr s22, [x25, #0x0]\n" + "ldr s17, [x24, #0x0]\n" + "ldr s21, [x23, #0x0]\n" + "ldr s16, [x22, #0x0]\n" + "ldr s20, [x21, #0x0]\n" + "5:" // Odd load end + ".inst 0x0ea16a73 // bfcvtn 
v19.4h, v19.4s\n" + ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n" + ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + ".inst 0x4ea16af3 // bfcvtn2 v19.8h, v23.4s\n" + ".inst 0x4ea16ad2 // bfcvtn2 v18.8h, v22.4s\n" + ".inst 0x4ea16ab1 // bfcvtn2 v17.8h, v21.4s\n" + ".inst 0x4ea16a90 // bfcvtn2 v16.8h, v20.4s\n" + "str q19, [%x[out_ptr], #0x0]\n" + "str q18, [%x[out_ptr], #0x10]\n" + "str q17, [%x[out_ptr], #0x20]\n" + "str q16, [%x[out_ptr], #0x30]\n" + "add %x[out_ptr], %x[out_ptr], #0x40\n" + "6:" // Odds skip + : [out_ptr] "+&r"(out), [width] "+&r"(width) + : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset) + : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "x20", "x21", "x22", "x23", "x24", + "x25", "x26", "x27", "x28"); + } +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_8x4_f32_bf16_neon.h b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.h similarity index 51% rename from kai/ukernels/matmul/pack/kai_lhs_pack_8x4_f32_bf16_neon.h rename to kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.h index 5cd514c9..dd47ba88 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_8x4_f32_bf16_neon.h +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.h @@ -11,21 +11,21 @@ extern "C" { #include #include + #include "kai/kai_common.h" -size_t kai_get_m_step_lhs_pack_8x4_f32_bf16_neon(size_t mr); +size_t kai_get_m_step_lhs_pack_f32p8x4_bf16_neon(size_t mr); -size_t kai_get_lhs_offset_lhs_pack_8x4_f32_bf16_neon(size_t m_idx, size_t lhs_stride); +size_t kai_get_lhs_offset_lhs_pack_f32p8x4_bf16_neon(size_t m_idx, size_t lhs_stride); -size_t kai_get_lhs_packed_offset_lhs_pack_8x4_f32_bf16_neon(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr); +size_t kai_get_lhs_packed_offset_lhs_pack_f32p8x4_bf16_neon(size_t m_idx, size_t k); -size_t kai_get_lhs_packed_size_lhs_pack_8x4_f32_bf16_neon(size_t m, size_t k, size_t mr, 
size_t kr, size_t sr); +size_t kai_get_lhs_packed_size_lhs_pack_f32p8x4_bf16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr); -void kai_run_lhs_pack_8x4_f32_bf16_neon( - size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, - const void* lhs, size_t lhs_stride, void* lhs_packed -); +void kai_run_lhs_pack_f32p8x4_bf16_neon( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, + void* lhs_packed); #ifdef __cplusplus } // extern "C" -#endif // __cplusplus \ No newline at end of file +#endif // __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.c b/kai/ukernels/matmul/pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.c deleted file mode 100644 index 36f60c26..00000000 --- a/kai/ukernels/matmul/pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.c +++ /dev/null @@ -1,474 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// - -#if !defined(__aarch64__) -#error This file must be compiled for AArch64. -#else // Architectural features check. 
- -#include - -#include -#include -#include -#include - -#include "kai/kai_common.h" - -static const size_t kai_nr = 12; -static const size_t kai_kr = 4; - -size_t kai_get_n_step_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(void) { - return kai_nr; -} - -size_t kai_get_rhs_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(size_t n_idx) { - KAI_ASSUME(n_idx % kai_nr == 0); - - return n_idx * sizeof(float); -} - - -size_t kai_get_bias_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(size_t n_idx) { - return n_idx * sizeof(uint32_t); -} - - -size_t kai_get_rhs_packed_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(size_t n_idx, size_t k) { - KAI_ASSUME(n_idx % kai_nr == 0); - - return n_idx * (sizeof(uint32_t) + kai_roundup(k, kai_kr) * sizeof(bfloat16_t)); -} - -size_t kai_get_rhs_packed_size_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(size_t n, size_t k) { - return kai_get_rhs_packed_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(kai_roundup(n, kai_nr), k); -} - -void kai_run_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12( - size_t num_groups, - size_t n, - size_t k, - size_t nr, - size_t kr, - size_t sr, - size_t rhs_stride, - const void* rhs, - const void* bias, - const void* scale, - void* rhs_packed, - size_t extra_bytes, - const void* params) { - KAI_ASSUME(num_groups == 1); - KAI_ASSUME(nr == kai_nr); - KAI_ASSUME(kr == kai_kr); - KAI_ASSUME(sr == 1); - KAI_ASSUME(rhs != NULL); - KAI_ASSUME(bias != NULL); - KAI_ASSUME(scale == NULL); - KAI_ASSUME(rhs_packed != NULL); - KAI_ASSUME(extra_bytes == 0); - KAI_ASSUME(params == NULL); - - size_t height = k; - const size_t width = n; - const void* in = rhs; - void* out = rhs_packed; - const size_t in_stride = rhs_stride; - float *pad_row = (float*)alloca(width * sizeof(float)); - - if (height % 4) { - memset(pad_row, 0, width * sizeof(float)); - } - - size_t 
out_stride = kai_nr * kai_roundup(height, 4) * sizeof(bfloat16_t) + kai_nr * sizeof(uint32_t); - - __asm__ __volatile__( - "mov x22, %x[width]\n" - "mov x21, %x[out]\n" - "cmp x22, #0xc\n" - "blt 2f\n" - "1:" // Bias: Full loop - "ldr q16, [%x[bias], #0x0]\n" - "ldr q26, [%x[bias], #0x10]\n" - "sub x22, x22, #0xc\n" - "ldr q8, [%x[bias], #0x20]\n" - "cmp x22, #0xc\n" - "add %x[bias], %x[bias], #0x30\n" - "str q16, [x21, #0x0]\n" - "str q26, [x21, #0x10]\n" - "str q8, [x21, #0x20]\n" - "add x21, x21, %x[out_stride]\n" - "bge 1b\n" - "cbz x22, 3f\n" - "2:" // Bias: Tail loop - "ldr w20, [%x[bias], #0x0]\n" - "sub x22, x22, #0x1\n" - "add %x[bias], %x[bias], #0x4\n" - "cmp x22, #0x0\n" - "str w20, [x21]\n" - "add x21, x21, #0x4\n" - "bgt 2b\n" - "3:" // Bias: Done - "cmp %x[height], #0x8\n" - "add %x[out], %x[out], #0x30\n" - "blt 12f\n" - "4:" // Main row loop: Head - "mov x9, %x[in]\n" - "mov x28, %x[width]\n" - "mov x27, %x[out]\n" - "sub %x[height], %x[height], #0x8\n" - "add x26, x9, %x[in_stride]\n" - "add x25, x26, %x[in_stride]\n" - "add x24, x25, %x[in_stride]\n" - "cmp x28, #0xc\n" - "add x23, x24, %x[in_stride]\n" - "add x22, x23, %x[in_stride]\n" - "add x21, x22, %x[in_stride]\n" - "add x20, x21, %x[in_stride]\n" - "add %x[in], x20, %x[in_stride]\n" - "blt 6f\n" - "5:" // Main row loop: Column loop - "ldr q28, [x9], #0x10\n" - "ldr q27, [x26], #0x10\n" - "sub x28, x28, #0xc\n" - "ldr q11, [x25], #0x10\n" - "ldr q5, [x24], #0x10\n" - "cmp x28, #0xc\n" - "ldr q14, [x23], #0x10\n" - "ldr q6, [x22], #0x10\n" - "ldr q2, [x21], #0x10\n" - "ldr q18, [x20], #0x10\n" - "ldr q1, [x9], #0x10\n" - "ldr q7, [x26], #0x10\n" - "zip1 v15.4s, v28.4s, v11.4s\n" - "zip1 v8.4s, v27.4s, v5.4s\n" - "ldr q3, [x25], #0x10\n" - "ldr q23, [x24], #0x10\n" - "zip2 v17.4s, v28.4s, v11.4s\n" - "zip2 v27.4s, v27.4s, v5.4s\n" - "ldr q5, [x23], #0x10\n" - "ldr q30, [x22], #0x10\n" - "zip1 v26.4s, v14.4s, v2.4s\n" - "zip1 v31.4s, v6.4s, v18.4s\n" - "ldr q20, [x21], #0x10\n" - "ldr q16, 
[x20], #0x10\n" - "zip2 v12.4s, v14.4s, v2.4s\n" - "zip2 v24.4s, v6.4s, v18.4s\n" - "ldr q29, [x9], #0x10\n" - "ldr q6, [x26], #0x10\n" - "zip1 v18.4s, v1.4s, v3.4s\n" - "zip1 v4.4s, v7.4s, v23.4s\n" - "ldr q22, [x25], #0x10\n" - "ldr q0, [x24], #0x10\n" - "zip2 v3.4s, v1.4s, v3.4s\n" - "zip2 v1.4s, v7.4s, v23.4s\n" - "ldr q2, [x23], #0x10\n" - "ldr q10, [x22], #0x10\n" - "zip1 v28.4s, v5.4s, v20.4s\n" - "zip1 v14.4s, v30.4s, v16.4s\n" - "ldr q9, [x21], #0x10\n" - "ldr q23, [x20], #0x10\n" - "zip2 v13.4s, v5.4s, v20.4s\n" - "zip2 v30.4s, v30.4s, v16.4s\n" - "zip1 v16.4s, v29.4s, v22.4s\n" - "zip1 v5.4s, v6.4s, v0.4s\n" - "zip2 v22.4s, v29.4s, v22.4s\n" - "zip2 v0.4s, v6.4s, v0.4s\n" - "zip1 v7.4s, v2.4s, v9.4s\n" - "zip1 v19.4s, v10.4s, v23.4s\n" - "zip2 v21.4s, v2.4s, v9.4s\n" - "zip2 v25.4s, v10.4s, v23.4s\n" - "zip1 v11.4s, v15.4s, v8.4s\n" - "zip1 v9.4s, v17.4s, v27.4s\n" - "zip1 v6.4s, v18.4s, v4.4s\n" - "zip1 v2.4s, v3.4s, v1.4s\n" - "zip1 v29.4s, v16.4s, v5.4s\n" - "zip1 v20.4s, v22.4s, v0.4s\n" - "zip1 v10.4s, v26.4s, v31.4s\n" - "zip1 v23.4s, v12.4s, v24.4s\n" - ".inst 0x0ea1696b // bfcvtn v11.4h, v11.4s\n" - "zip2 v8.4s, v15.4s, v8.4s\n" - "zip1 v15.4s, v28.4s, v14.4s\n" - ".inst 0x0ea16929 // bfcvtn v9.4h, v9.4s\n" - "zip2 v27.4s, v17.4s, v27.4s\n" - "zip1 v17.4s, v13.4s, v30.4s\n" - ".inst 0x0ea168c6 // bfcvtn v6.4h, v6.4s\n" - "zip2 v4.4s, v18.4s, v4.4s\n" - "zip1 v18.4s, v7.4s, v19.4s\n" - ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" - "zip2 v1.4s, v3.4s, v1.4s\n" - "zip1 v3.4s, v21.4s, v25.4s\n" - ".inst 0x0ea16bbd // bfcvtn v29.4h, v29.4s\n" - "zip2 v5.4s, v16.4s, v5.4s\n" - ".inst 0x0ea16a94 // bfcvtn v20.4h, v20.4s\n" - "zip2 v0.4s, v22.4s, v0.4s\n" - ".inst 0x0ea16956 // bfcvtn v22.4h, v10.4s\n" - "zip2 v31.4s, v26.4s, v31.4s\n" - ".inst 0x0ea16aea // bfcvtn v10.4h, v23.4s\n" - "zip2 v26.4s, v12.4s, v24.4s\n" - ".inst 0x0ea169ef // bfcvtn v15.4h, v15.4s\n" - "zip2 v12.4s, v28.4s, v14.4s\n" - ".inst 0x0ea16a2e // bfcvtn v14.4h, v17.4s\n" - "zip2 
v24.4s, v13.4s, v30.4s\n" - ".inst 0x0ea16a57 // bfcvtn v23.4h, v18.4s\n" - "zip2 v18.4s, v7.4s, v19.4s\n" - ".inst 0x0ea16871 // bfcvtn v17.4h, v3.4s\n" - "zip2 v16.4s, v21.4s, v25.4s\n" - ".inst 0x4ea1690b // bfcvtn2 v11.8h, v8.4s\n" - ".inst 0x4ea16b69 // bfcvtn2 v9.8h, v27.4s\n" - ".inst 0x4ea16886 // bfcvtn2 v6.8h, v4.4s\n" - ".inst 0x4ea16822 // bfcvtn2 v2.8h, v1.4s\n" - ".inst 0x4ea168bd // bfcvtn2 v29.8h, v5.4s\n" - ".inst 0x4ea16814 // bfcvtn2 v20.8h, v0.4s\n" - ".inst 0x4ea16bf6 // bfcvtn2 v22.8h, v31.4s\n" - ".inst 0x4ea16b4a // bfcvtn2 v10.8h, v26.4s\n" - "str q11, [x27, #0x0]\n" - ".inst 0x4ea1698f // bfcvtn2 v15.8h, v12.4s\n" - ".inst 0x4ea16b0e // bfcvtn2 v14.8h, v24.4s\n" - "str q9, [x27, #0x10]\n" - ".inst 0x4ea16a57 // bfcvtn2 v23.8h, v18.4s\n" - ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" - "str q6, [x27, #0x20]\n" - "str q2, [x27, #0x30]\n" - "str q29, [x27, #0x40]\n" - "str q20, [x27, #0x50]\n" - "str q22, [x27, #0x60]\n" - "str q10, [x27, #0x70]\n" - "str q15, [x27, #0x80]\n" - "str q14, [x27, #0x90]\n" - "str q23, [x27, #0xa0]\n" - "str q17, [x27, #0xb0]\n" - "add x27, x27, %x[out_stride]\n" - "bge 5b\n" - "6:" // Main row loop: Column loop skip - "cbz x28, 11f\n" - "cmp x28, #0x4\n" - "movi v16.16b, #0x0\n" - "str q16, [x27, #0x0]\n" - "str q16, [x27, #0x10]\n" - "str q16, [x27, #0x20]\n" - "str q16, [x27, #0x30]\n" - "str q16, [x27, #0x40]\n" - "str q16, [x27, #0x50]\n" - "str q16, [x27, #0x60]\n" - "str q16, [x27, #0x70]\n" - "str q16, [x27, #0x80]\n" - "str q16, [x27, #0x90]\n" - "str q16, [x27, #0xa0]\n" - "str q16, [x27, #0xb0]\n" - "blt 8f\n" - "7:" // Main row loop: width 4 loop: loop - "ldr q25, [x9], #0x10\n" - "ldr q24, [x26], #0x10\n" - "sub x28, x28, #0x4\n" - "ldr q21, [x25], #0x10\n" - "ldr q20, [x24], #0x10\n" - "cmp x28, #0x4\n" - "ldr q23, [x23], #0x10\n" - "ldr q19, [x22], #0x10\n" - "ldr q18, [x21], #0x10\n" - "ldr q17, [x20], #0x10\n" - "zip1 v22.4s, v25.4s, v21.4s\n" - "zip1 v16.4s, v24.4s, v20.4s\n" - "zip2 v21.4s, 
v25.4s, v21.4s\n" - "zip2 v20.4s, v24.4s, v20.4s\n" - "zip1 v27.4s, v23.4s, v18.4s\n" - "zip1 v26.4s, v19.4s, v17.4s\n" - "zip2 v25.4s, v23.4s, v18.4s\n" - "zip2 v24.4s, v19.4s, v17.4s\n" - "zip1 v19.4s, v22.4s, v16.4s\n" - "zip1 v18.4s, v21.4s, v20.4s\n" - "zip1 v17.4s, v27.4s, v26.4s\n" - "zip2 v23.4s, v22.4s, v16.4s\n" - "zip1 v16.4s, v25.4s, v24.4s\n" - "zip2 v22.4s, v21.4s, v20.4s\n" - ".inst 0x0ea16a75 // bfcvtn v21.4h, v19.4s\n" - ".inst 0x0ea16a54 // bfcvtn v20.4h, v18.4s\n" - ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" - "zip2 v18.4s, v27.4s, v26.4s\n" - ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" - "zip2 v16.4s, v25.4s, v24.4s\n" - ".inst 0x4ea16af5 // bfcvtn2 v21.8h, v23.4s\n" - ".inst 0x4ea16ad4 // bfcvtn2 v20.8h, v22.4s\n" - ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" - ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" - "str q21, [x27, #0x0]\n" - "str q20, [x27, #0x10]\n" - "str q19, [x27, #0x60]\n" - "str q17, [x27, #0x70]\n" - "add x27, x27, #0x20\n" - "bge 7b\n" - "8:" // Main row loop: width 4 loop: skip - "cmp x28, #0x1\n" - "blt 10f\n" - "9:" // Main row loop: width 1 loop: loop - "ldr s23, [x9], #0x4\n" - "ldr s22, [x26], #0x4\n" - "sub x28, x28, #0x1\n" - "ldr s19, [x25], #0x4\n" - "ldr s17, [x24], #0x4\n" - "cmp x28, #0x1\n" - "ldr s21, [x23], #0x4\n" - "ldr s20, [x22], #0x4\n" - "ldr s18, [x21], #0x4\n" - "ldr s16, [x20], #0x4\n" - "zip1 v19.4s, v23.4s, v19.4s\n" - "zip1 v17.4s, v22.4s, v17.4s\n" - "zip1 v18.4s, v21.4s, v18.4s\n" - "zip1 v16.4s, v20.4s, v16.4s\n" - "zip1 v17.4s, v19.4s, v17.4s\n" - "zip1 v16.4s, v18.4s, v16.4s\n" - ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" - ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" - "str d17, [x27, #0x0]\n" - "str d16, [x27, #0x60]\n" - "add x27, x27, #0x8\n" - "bge 9b\n" - "10:" // Main row loop: width 1 loop: skip - "11:" // Main row loop: odd col skip - "cmp %x[height], #0x8\n" - "add %x[out], %x[out], #0xc0\n" - "bge 4b\n" - "cbz %x[height], 21f\n" - "12:" // Main loop skip - "13:" // Tail 
row loop: Head - "mov x9, %x[in]\n" - "mov x20, %x[width]\n" - "cmp %x[height], #0x3\n" - "mov x27, %x[out]\n" - "add x26, x9, %x[in_stride]\n" - "add x25, x26, %x[in_stride]\n" - "add x24, x25, %x[in_stride]\n" - "csel x25, x25, %x[pad_row], GE\n" - "add %x[in], x24, %x[in_stride]\n" - "csel x24, x24, %x[pad_row], GT\n" - "cmp %x[height], #0x1\n" - "sub %x[height], %x[height], #0x4\n" - "csel x26, x26, %x[pad_row], GT\n" - "cmp x20, #0xc\n" - "blt 15f\n" - "14:" // Tail row loop: Column loop - "ldr q24, [x9], #0x10\n" - "ldr q23, [x26], #0x10\n" - "sub x20, x20, #0xc\n" - "ldr q22, [x25], #0x10\n" - "ldr q16, [x24], #0x10\n" - "cmp x20, #0xc\n" - "ldr q28, [x9], #0x10\n" - "ldr q27, [x26], #0x10\n" - "ldr q21, [x25], #0x10\n" - "ldr q20, [x24], #0x10\n" - "ldr q19, [x9], #0x10\n" - "zip1 v26.4s, v24.4s, v22.4s\n" - "zip1 v25.4s, v23.4s, v16.4s\n" - "ldr q18, [x26], #0x10\n" - "ldr q17, [x25], #0x10\n" - "zip2 v24.4s, v24.4s, v22.4s\n" - "zip2 v23.4s, v23.4s, v16.4s\n" - "ldr q16, [x24], #0x10\n" - "zip1 v2.4s, v28.4s, v21.4s\n" - "zip1 v22.4s, v27.4s, v20.4s\n" - "zip2 v1.4s, v28.4s, v21.4s\n" - "zip2 v0.4s, v27.4s, v20.4s\n" - "zip1 v31.4s, v19.4s, v17.4s\n" - "zip1 v30.4s, v18.4s, v16.4s\n" - "zip2 v29.4s, v19.4s, v17.4s\n" - "zip2 v28.4s, v18.4s, v16.4s\n" - "zip1 v21.4s, v26.4s, v25.4s\n" - "zip1 v20.4s, v24.4s, v23.4s\n" - "zip1 v19.4s, v2.4s, v22.4s\n" - "zip1 v18.4s, v1.4s, v0.4s\n" - "zip1 v17.4s, v31.4s, v30.4s\n" - "zip1 v16.4s, v29.4s, v28.4s\n" - ".inst 0x0ea16abb // bfcvtn v27.4h, v21.4s\n" - "zip2 v26.4s, v26.4s, v25.4s\n" - ".inst 0x0ea16a99 // bfcvtn v25.4h, v20.4s\n" - "zip2 v24.4s, v24.4s, v23.4s\n" - ".inst 0x0ea16a77 // bfcvtn v23.4h, v19.4s\n" - "zip2 v22.4s, v2.4s, v22.4s\n" - ".inst 0x0ea16a55 // bfcvtn v21.4h, v18.4s\n" - "zip2 v20.4s, v1.4s, v0.4s\n" - ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" - "zip2 v18.4s, v31.4s, v30.4s\n" - ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" - "zip2 v16.4s, v29.4s, v28.4s\n" - ".inst 0x4ea16b5b // 
bfcvtn2 v27.8h, v26.4s\n" - ".inst 0x4ea16b19 // bfcvtn2 v25.8h, v24.4s\n" - ".inst 0x4ea16ad7 // bfcvtn2 v23.8h, v22.4s\n" - ".inst 0x4ea16a95 // bfcvtn2 v21.8h, v20.4s\n" - ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" - ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" - "str q27, [x27, #0x0]\n" - "str q25, [x27, #0x10]\n" - "str q23, [x27, #0x20]\n" - "str q21, [x27, #0x30]\n" - "str q19, [x27, #0x40]\n" - "str q17, [x27, #0x50]\n" - "add x27, x27, %x[out_stride]\n" - "bge 14b\n" - "15:" // Tail row loop: Column loop skip - "cbz x20, 20f\n" - "cmp x20, #0x4\n" - "movi v16.16b, #0x0\n" - "str q16, [x27, #0x0]\n" - "str q16, [x27, #0x10]\n" - "str q16, [x27, #0x20]\n" - "str q16, [x27, #0x30]\n" - "str q16, [x27, #0x40]\n" - "str q16, [x27, #0x50]\n" - "blt 17f\n" - "16:" // Tail row loop: width 4 loop: loop - "ldr q21, [x9], #0x10\n" - "ldr q20, [x26], #0x10\n" - "sub x20, x20, #0x4\n" - "ldr q19, [x25], #0x10\n" - "ldr q17, [x24], #0x10\n" - "cmp x20, #0x4\n" - "zip1 v18.4s, v21.4s, v19.4s\n" - "zip1 v16.4s, v20.4s, v17.4s\n" - "zip2 v21.4s, v21.4s, v19.4s\n" - "zip2 v20.4s, v20.4s, v17.4s\n" - "zip1 v17.4s, v18.4s, v16.4s\n" - "zip2 v19.4s, v18.4s, v16.4s\n" - "zip1 v16.4s, v21.4s, v20.4s\n" - ".inst 0x0ea16a32 // bfcvtn v18.4h, v17.4s\n" - "zip2 v17.4s, v21.4s, v20.4s\n" - ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" - ".inst 0x4ea16a72 // bfcvtn2 v18.8h, v19.4s\n" - ".inst 0x4ea16a30 // bfcvtn2 v16.8h, v17.4s\n" - "str q18, [x27, #0x0]\n" - "str q16, [x27, #0x10]\n" - "add x27, x27, #0x20\n" - "bge 16b\n" - "17:" // Tail row loop: width 4 loop: skip - "cmp x20, #0x1\n" - "blt 19f\n" - "18:" // Tail row loop: width 1 loop: loop - "ldr s19, [x9], #0x4\n" - "ldr s18, [x26], #0x4\n" - "sub x20, x20, #0x1\n" - "ldr s17, [x25], #0x4\n" - "ldr s16, [x24], #0x4\n" - "cmp x20, #0x1\n" - "zip1 v17.4s, v19.4s, v17.4s\n" - "zip1 v16.4s, v18.4s, v16.4s\n" - "zip1 v16.4s, v17.4s, v16.4s\n" - ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" - "str d16, [x27, #0x0]\n" - "add 
x27, x27, #0x8\n" - "bge 18b\n" - "19:" // Tail row loop: width 1 loop: skip - "20:" // Tail row loop: odd col skip - "cmp %x[height], #0x1\n" - "add %x[out], %x[out], #0x60\n" - "bge 13b\n" - "21:" // Done - : [bias] "+&r" (bias), [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) - : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" - ); -} - -#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c new file mode 100644 index 00000000..943067e5 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c @@ -0,0 +1,461 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) +#error This file must be compiled for AArch64. +#else // Architectural features check. 
+ +#include +#include +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_nr = 12; +static const size_t kai_kr = 4; + +size_t kai_get_n_step_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(void) { + return kai_nr; +} + +size_t kai_get_rhs_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n_idx) { + KAI_ASSUME(n_idx % kai_nr == 0); + + return n_idx * sizeof(float); +} + +size_t kai_get_bias_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n_idx) { + return n_idx * sizeof(uint32_t); +} + +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_nr == 0); + + return n_idx * (sizeof(uint32_t) + kai_roundup(k, kai_kr) * sizeof(bfloat16_t)); +} + +size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n, size_t k) { + return kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(kai_roundup(n, kai_nr), k); +} + +void kai_run_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { + KAI_ASSUME(num_groups == 1); + KAI_ASSUME(nr == kai_nr); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == 1); + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(bias != NULL); + KAI_ASSUME(scale == NULL); + KAI_ASSUME(rhs_packed != NULL); + KAI_ASSUME(extra_bytes == 0); + KAI_ASSUME(params == NULL); + + size_t height = k; + const size_t width = n; + const void* in = rhs; + void* out = rhs_packed; + const size_t in_stride = rhs_stride; + float* pad_row = (float*)alloca(width * sizeof(float)); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(float)); + } + + size_t out_stride = kai_nr * kai_roundup(height, 4) * sizeof(bfloat16_t) + kai_nr * sizeof(uint32_t); + + __asm__ __volatile__( + "mov x22, %x[width]\n" + "mov x21, %x[out]\n" + "cmp x22, #0xc\n" + 
"blt 2f\n" + "1:" // Bias: Full loop + "ldr q16, [%x[bias], #0x0]\n" + "ldr q26, [%x[bias], #0x10]\n" + "sub x22, x22, #0xc\n" + "ldr q8, [%x[bias], #0x20]\n" + "cmp x22, #0xc\n" + "add %x[bias], %x[bias], #0x30\n" + "str q16, [x21, #0x0]\n" + "str q26, [x21, #0x10]\n" + "str q8, [x21, #0x20]\n" + "add x21, x21, %x[out_stride]\n" + "bge 1b\n" + "cbz x22, 3f\n" + "2:" // Bias: Tail loop + "ldr w20, [%x[bias], #0x0]\n" + "sub x22, x22, #0x1\n" + "add %x[bias], %x[bias], #0x4\n" + "cmp x22, #0x0\n" + "str w20, [x21]\n" + "add x21, x21, #0x4\n" + "bgt 2b\n" + "3:" // Bias: Done + "cmp %x[height], #0x8\n" + "add %x[out], %x[out], #0x30\n" + "blt 12f\n" + "4:" // Main row loop: Head + "mov x9, %x[in]\n" + "mov x28, %x[width]\n" + "mov x27, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "add x26, x9, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "cmp x28, #0xc\n" + "add x23, x24, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x21, x22, %x[in_stride]\n" + "add x20, x21, %x[in_stride]\n" + "add %x[in], x20, %x[in_stride]\n" + "blt 6f\n" + "5:" // Main row loop: Column loop + "ldr q28, [x9], #0x10\n" + "ldr q27, [x26], #0x10\n" + "sub x28, x28, #0xc\n" + "ldr q11, [x25], #0x10\n" + "ldr q5, [x24], #0x10\n" + "cmp x28, #0xc\n" + "ldr q14, [x23], #0x10\n" + "ldr q6, [x22], #0x10\n" + "ldr q2, [x21], #0x10\n" + "ldr q18, [x20], #0x10\n" + "ldr q1, [x9], #0x10\n" + "ldr q7, [x26], #0x10\n" + "zip1 v15.4s, v28.4s, v11.4s\n" + "zip1 v8.4s, v27.4s, v5.4s\n" + "ldr q3, [x25], #0x10\n" + "ldr q23, [x24], #0x10\n" + "zip2 v17.4s, v28.4s, v11.4s\n" + "zip2 v27.4s, v27.4s, v5.4s\n" + "ldr q5, [x23], #0x10\n" + "ldr q30, [x22], #0x10\n" + "zip1 v26.4s, v14.4s, v2.4s\n" + "zip1 v31.4s, v6.4s, v18.4s\n" + "ldr q20, [x21], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip2 v12.4s, v14.4s, v2.4s\n" + "zip2 v24.4s, v6.4s, v18.4s\n" + "ldr q29, [x9], #0x10\n" + "ldr q6, [x26], #0x10\n" + "zip1 v18.4s, v1.4s, v3.4s\n" + "zip1 v4.4s, v7.4s, 
v23.4s\n" + "ldr q22, [x25], #0x10\n" + "ldr q0, [x24], #0x10\n" + "zip2 v3.4s, v1.4s, v3.4s\n" + "zip2 v1.4s, v7.4s, v23.4s\n" + "ldr q2, [x23], #0x10\n" + "ldr q10, [x22], #0x10\n" + "zip1 v28.4s, v5.4s, v20.4s\n" + "zip1 v14.4s, v30.4s, v16.4s\n" + "ldr q9, [x21], #0x10\n" + "ldr q23, [x20], #0x10\n" + "zip2 v13.4s, v5.4s, v20.4s\n" + "zip2 v30.4s, v30.4s, v16.4s\n" + "zip1 v16.4s, v29.4s, v22.4s\n" + "zip1 v5.4s, v6.4s, v0.4s\n" + "zip2 v22.4s, v29.4s, v22.4s\n" + "zip2 v0.4s, v6.4s, v0.4s\n" + "zip1 v7.4s, v2.4s, v9.4s\n" + "zip1 v19.4s, v10.4s, v23.4s\n" + "zip2 v21.4s, v2.4s, v9.4s\n" + "zip2 v25.4s, v10.4s, v23.4s\n" + "zip1 v11.4s, v15.4s, v8.4s\n" + "zip1 v9.4s, v17.4s, v27.4s\n" + "zip1 v6.4s, v18.4s, v4.4s\n" + "zip1 v2.4s, v3.4s, v1.4s\n" + "zip1 v29.4s, v16.4s, v5.4s\n" + "zip1 v20.4s, v22.4s, v0.4s\n" + "zip1 v10.4s, v26.4s, v31.4s\n" + "zip1 v23.4s, v12.4s, v24.4s\n" + ".inst 0x0ea1696b // bfcvtn v11.4h, v11.4s\n" + "zip2 v8.4s, v15.4s, v8.4s\n" + "zip1 v15.4s, v28.4s, v14.4s\n" + ".inst 0x0ea16929 // bfcvtn v9.4h, v9.4s\n" + "zip2 v27.4s, v17.4s, v27.4s\n" + "zip1 v17.4s, v13.4s, v30.4s\n" + ".inst 0x0ea168c6 // bfcvtn v6.4h, v6.4s\n" + "zip2 v4.4s, v18.4s, v4.4s\n" + "zip1 v18.4s, v7.4s, v19.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "zip2 v1.4s, v3.4s, v1.4s\n" + "zip1 v3.4s, v21.4s, v25.4s\n" + ".inst 0x0ea16bbd // bfcvtn v29.4h, v29.4s\n" + "zip2 v5.4s, v16.4s, v5.4s\n" + ".inst 0x0ea16a94 // bfcvtn v20.4h, v20.4s\n" + "zip2 v0.4s, v22.4s, v0.4s\n" + ".inst 0x0ea16956 // bfcvtn v22.4h, v10.4s\n" + "zip2 v31.4s, v26.4s, v31.4s\n" + ".inst 0x0ea16aea // bfcvtn v10.4h, v23.4s\n" + "zip2 v26.4s, v12.4s, v24.4s\n" + ".inst 0x0ea169ef // bfcvtn v15.4h, v15.4s\n" + "zip2 v12.4s, v28.4s, v14.4s\n" + ".inst 0x0ea16a2e // bfcvtn v14.4h, v17.4s\n" + "zip2 v24.4s, v13.4s, v30.4s\n" + ".inst 0x0ea16a57 // bfcvtn v23.4h, v18.4s\n" + "zip2 v18.4s, v7.4s, v19.4s\n" + ".inst 0x0ea16871 // bfcvtn v17.4h, v3.4s\n" + "zip2 v16.4s, v21.4s, v25.4s\n" + 
".inst 0x4ea1690b // bfcvtn2 v11.8h, v8.4s\n" + ".inst 0x4ea16b69 // bfcvtn2 v9.8h, v27.4s\n" + ".inst 0x4ea16886 // bfcvtn2 v6.8h, v4.4s\n" + ".inst 0x4ea16822 // bfcvtn2 v2.8h, v1.4s\n" + ".inst 0x4ea168bd // bfcvtn2 v29.8h, v5.4s\n" + ".inst 0x4ea16814 // bfcvtn2 v20.8h, v0.4s\n" + ".inst 0x4ea16bf6 // bfcvtn2 v22.8h, v31.4s\n" + ".inst 0x4ea16b4a // bfcvtn2 v10.8h, v26.4s\n" + "str q11, [x27, #0x0]\n" + ".inst 0x4ea1698f // bfcvtn2 v15.8h, v12.4s\n" + ".inst 0x4ea16b0e // bfcvtn2 v14.8h, v24.4s\n" + "str q9, [x27, #0x10]\n" + ".inst 0x4ea16a57 // bfcvtn2 v23.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q6, [x27, #0x20]\n" + "str q2, [x27, #0x30]\n" + "str q29, [x27, #0x40]\n" + "str q20, [x27, #0x50]\n" + "str q22, [x27, #0x60]\n" + "str q10, [x27, #0x70]\n" + "str q15, [x27, #0x80]\n" + "str q14, [x27, #0x90]\n" + "str q23, [x27, #0xa0]\n" + "str q17, [x27, #0xb0]\n" + "add x27, x27, %x[out_stride]\n" + "bge 5b\n" + "6:" // Main row loop: Column loop skip + "cbz x28, 11f\n" + "cmp x28, #0x4\n" + "movi v16.16b, #0x0\n" + "str q16, [x27, #0x0]\n" + "str q16, [x27, #0x10]\n" + "str q16, [x27, #0x20]\n" + "str q16, [x27, #0x30]\n" + "str q16, [x27, #0x40]\n" + "str q16, [x27, #0x50]\n" + "str q16, [x27, #0x60]\n" + "str q16, [x27, #0x70]\n" + "str q16, [x27, #0x80]\n" + "str q16, [x27, #0x90]\n" + "str q16, [x27, #0xa0]\n" + "str q16, [x27, #0xb0]\n" + "blt 8f\n" + "7:" // Main row loop: width 4 loop: loop + "ldr q25, [x9], #0x10\n" + "ldr q24, [x26], #0x10\n" + "sub x28, x28, #0x4\n" + "ldr q21, [x25], #0x10\n" + "ldr q20, [x24], #0x10\n" + "cmp x28, #0x4\n" + "ldr q23, [x23], #0x10\n" + "ldr q19, [x22], #0x10\n" + "ldr q18, [x21], #0x10\n" + "ldr q17, [x20], #0x10\n" + "zip1 v22.4s, v25.4s, v21.4s\n" + "zip1 v16.4s, v24.4s, v20.4s\n" + "zip2 v21.4s, v25.4s, v21.4s\n" + "zip2 v20.4s, v24.4s, v20.4s\n" + "zip1 v27.4s, v23.4s, v18.4s\n" + "zip1 v26.4s, v19.4s, v17.4s\n" + "zip2 v25.4s, v23.4s, v18.4s\n" + "zip2 v24.4s, v19.4s, v17.4s\n" + 
"zip1 v19.4s, v22.4s, v16.4s\n" + "zip1 v18.4s, v21.4s, v20.4s\n" + "zip1 v17.4s, v27.4s, v26.4s\n" + "zip2 v23.4s, v22.4s, v16.4s\n" + "zip1 v16.4s, v25.4s, v24.4s\n" + "zip2 v22.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a75 // bfcvtn v21.4h, v19.4s\n" + ".inst 0x0ea16a54 // bfcvtn v20.4h, v18.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v27.4s, v26.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v25.4s, v24.4s\n" + ".inst 0x4ea16af5 // bfcvtn2 v21.8h, v23.4s\n" + ".inst 0x4ea16ad4 // bfcvtn2 v20.8h, v22.4s\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q21, [x27, #0x0]\n" + "str q20, [x27, #0x10]\n" + "str q19, [x27, #0x60]\n" + "str q17, [x27, #0x70]\n" + "add x27, x27, #0x20\n" + "bge 7b\n" + "8:" // Main row loop: width 4 loop: skip + "cmp x28, #0x1\n" + "blt 10f\n" + "9:" // Main row loop: width 1 loop: loop + "ldr s23, [x9], #0x4\n" + "ldr s22, [x26], #0x4\n" + "sub x28, x28, #0x1\n" + "ldr s19, [x25], #0x4\n" + "ldr s17, [x24], #0x4\n" + "cmp x28, #0x1\n" + "ldr s21, [x23], #0x4\n" + "ldr s20, [x22], #0x4\n" + "ldr s18, [x21], #0x4\n" + "ldr s16, [x20], #0x4\n" + "zip1 v19.4s, v23.4s, v19.4s\n" + "zip1 v17.4s, v22.4s, v17.4s\n" + "zip1 v18.4s, v21.4s, v18.4s\n" + "zip1 v16.4s, v20.4s, v16.4s\n" + "zip1 v17.4s, v19.4s, v17.4s\n" + "zip1 v16.4s, v18.4s, v16.4s\n" + ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "str d17, [x27, #0x0]\n" + "str d16, [x27, #0x60]\n" + "add x27, x27, #0x8\n" + "bge 9b\n" + "10:" // Main row loop: width 1 loop: skip + "11:" // Main row loop: odd col skip + "cmp %x[height], #0x8\n" + "add %x[out], %x[out], #0xc0\n" + "bge 4b\n" + "cbz %x[height], 21f\n" + "12:" // Main loop skip + "13:" // Tail row loop: Head + "mov x9, %x[in]\n" + "mov x20, %x[width]\n" + "cmp %x[height], #0x3\n" + "mov x27, %x[out]\n" + "add x26, x9, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, 
x25, %x[in_stride]\n" + "csel x25, x25, %x[pad_row], GE\n" + "add %x[in], x24, %x[in_stride]\n" + "csel x24, x24, %x[pad_row], GT\n" + "cmp %x[height], #0x1\n" + "sub %x[height], %x[height], #0x4\n" + "csel x26, x26, %x[pad_row], GT\n" + "cmp x20, #0xc\n" + "blt 15f\n" + "14:" // Tail row loop: Column loop + "ldr q24, [x9], #0x10\n" + "ldr q23, [x26], #0x10\n" + "sub x20, x20, #0xc\n" + "ldr q22, [x25], #0x10\n" + "ldr q16, [x24], #0x10\n" + "cmp x20, #0xc\n" + "ldr q28, [x9], #0x10\n" + "ldr q27, [x26], #0x10\n" + "ldr q21, [x25], #0x10\n" + "ldr q20, [x24], #0x10\n" + "ldr q19, [x9], #0x10\n" + "zip1 v26.4s, v24.4s, v22.4s\n" + "zip1 v25.4s, v23.4s, v16.4s\n" + "ldr q18, [x26], #0x10\n" + "ldr q17, [x25], #0x10\n" + "zip2 v24.4s, v24.4s, v22.4s\n" + "zip2 v23.4s, v23.4s, v16.4s\n" + "ldr q16, [x24], #0x10\n" + "zip1 v2.4s, v28.4s, v21.4s\n" + "zip1 v22.4s, v27.4s, v20.4s\n" + "zip2 v1.4s, v28.4s, v21.4s\n" + "zip2 v0.4s, v27.4s, v20.4s\n" + "zip1 v31.4s, v19.4s, v17.4s\n" + "zip1 v30.4s, v18.4s, v16.4s\n" + "zip2 v29.4s, v19.4s, v17.4s\n" + "zip2 v28.4s, v18.4s, v16.4s\n" + "zip1 v21.4s, v26.4s, v25.4s\n" + "zip1 v20.4s, v24.4s, v23.4s\n" + "zip1 v19.4s, v2.4s, v22.4s\n" + "zip1 v18.4s, v1.4s, v0.4s\n" + "zip1 v17.4s, v31.4s, v30.4s\n" + "zip1 v16.4s, v29.4s, v28.4s\n" + ".inst 0x0ea16abb // bfcvtn v27.4h, v21.4s\n" + "zip2 v26.4s, v26.4s, v25.4s\n" + ".inst 0x0ea16a99 // bfcvtn v25.4h, v20.4s\n" + "zip2 v24.4s, v24.4s, v23.4s\n" + ".inst 0x0ea16a77 // bfcvtn v23.4h, v19.4s\n" + "zip2 v22.4s, v2.4s, v22.4s\n" + ".inst 0x0ea16a55 // bfcvtn v21.4h, v18.4s\n" + "zip2 v20.4s, v1.4s, v0.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v31.4s, v30.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v29.4s, v28.4s\n" + ".inst 0x4ea16b5b // bfcvtn2 v27.8h, v26.4s\n" + ".inst 0x4ea16b19 // bfcvtn2 v25.8h, v24.4s\n" + ".inst 0x4ea16ad7 // bfcvtn2 v23.8h, v22.4s\n" + ".inst 0x4ea16a95 // bfcvtn2 v21.8h, v20.4s\n" + ".inst 0x4ea16a53 // 
bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q27, [x27, #0x0]\n" + "str q25, [x27, #0x10]\n" + "str q23, [x27, #0x20]\n" + "str q21, [x27, #0x30]\n" + "str q19, [x27, #0x40]\n" + "str q17, [x27, #0x50]\n" + "add x27, x27, %x[out_stride]\n" + "bge 14b\n" + "15:" // Tail row loop: Column loop skip + "cbz x20, 20f\n" + "cmp x20, #0x4\n" + "movi v16.16b, #0x0\n" + "str q16, [x27, #0x0]\n" + "str q16, [x27, #0x10]\n" + "str q16, [x27, #0x20]\n" + "str q16, [x27, #0x30]\n" + "str q16, [x27, #0x40]\n" + "str q16, [x27, #0x50]\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr q21, [x9], #0x10\n" + "ldr q20, [x26], #0x10\n" + "sub x20, x20, #0x4\n" + "ldr q19, [x25], #0x10\n" + "ldr q17, [x24], #0x10\n" + "cmp x20, #0x4\n" + "zip1 v18.4s, v21.4s, v19.4s\n" + "zip1 v16.4s, v20.4s, v17.4s\n" + "zip2 v21.4s, v21.4s, v19.4s\n" + "zip2 v20.4s, v20.4s, v17.4s\n" + "zip1 v17.4s, v18.4s, v16.4s\n" + "zip2 v19.4s, v18.4s, v16.4s\n" + "zip1 v16.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a32 // bfcvtn v18.4h, v17.4s\n" + "zip2 v17.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + ".inst 0x4ea16a72 // bfcvtn2 v18.8h, v19.4s\n" + ".inst 0x4ea16a30 // bfcvtn2 v16.8h, v17.4s\n" + "str q18, [x27, #0x0]\n" + "str q16, [x27, #0x10]\n" + "add x27, x27, #0x20\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr s19, [x9], #0x4\n" + "ldr s18, [x26], #0x4\n" + "sub x20, x20, #0x1\n" + "ldr s17, [x25], #0x4\n" + "ldr s16, [x24], #0x4\n" + "cmp x20, #0x1\n" + "zip1 v17.4s, v19.4s, v17.4s\n" + "zip1 v16.4s, v18.4s, v16.4s\n" + "zip1 v16.4s, v17.4s, v16.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "str d16, [x27, #0x0]\n" + "add x27, x27, #0x8\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "20:" // Tail row loop: odd col skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x60\n" + "bge 13b\n" + "21:" 
// Done + : [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out) + : [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), [width] "r"(width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h similarity index 64% rename from kai/ukernels/matmul/pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.h rename to kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h index c12f253f..7ef90846 100644 --- a/kai/ukernels/matmul/pack/kai_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h @@ -17,23 +17,21 @@ extern "C" { /// The starting row index must be divisible by `n_step`. /// /// @return The n step value. -size_t kai_get_n_step_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(void); +size_t kai_get_n_step_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(void); /// Gets the offset in bytes to the data element in the RHS matrix buffer. /// /// @param[in] n_idx Column index. /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(size_t n_idx); - +size_t kai_get_rhs_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n_idx); /// Gets the offset in bytes to the data element in the bias buffer. /// /// @param[in] n_idx Column index. /// /// @return The offset in bytes to the data element. 
-size_t kai_get_bias_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(size_t n_idx); - +size_t kai_get_bias_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n_idx); /// Gets the offset in bytes to the data element in the packed RHS buffer. /// @@ -41,7 +39,7 @@ size_t kai_get_bias_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32 /// @param[in] k Number of columns. /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_packed_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n_idx, size_t k); /// Gets the size in bytes of the packed RHS buffer. /// @@ -49,16 +47,16 @@ size_t kai_get_rhs_packed_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf /// @param[in] k Number of columns. /// /// @return The size in bytes of the packed RHS buffer. -size_t kai_get_rhs_packed_size_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12(size_t n, size_t k); +size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n, size_t k); /// Runs the RHS packing function for matrix multiplication. /// /// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset /// calculated using the following functions: /// -/// * RHS: @ref kai_get_rhs_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12. -/// * Bias: @ref kai_get_bias_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12. -/// * Output: @ref kai_get_rhs_packed_offset_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12. +/// * RHS: @ref kai_get_rhs_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon. +/// * Bias: @ref kai_get_bias_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon. +/// * Output: @ref kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon. /// /// @param[in] num_groups Number of groups. It must be 1. 
/// @param[in] n Number of columns of the output matrix. @@ -73,20 +71,9 @@ size_t kai_get_rhs_packed_size_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16 /// @param[out] rhs_packed Packed RHS matrix. /// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. /// @param[in] params Extra packing parameters. It must be NULL. -void kai_run_matmul_transpose_pack_rhs_bias_bf16p16x4zf32_bf16_f32_neon_nr_12( - size_t num_groups, - size_t n, - size_t k, - size_t nr, - size_t kr, - size_t sr, - size_t rhs_stride, - const void* rhs, - const void* bias, - const void* scale, - void* rhs_packed, - size_t extra_bytes, - const void* params); +void kai_run_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params); #ifdef __cplusplus } // extern "C" diff --git a/test/common/MatMulMethod.hpp b/test/common/MatMulMethod.hpp new file mode 100644 index 00000000..daf91a7e --- /dev/null +++ b/test/common/MatMulMethod.hpp @@ -0,0 +1,330 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include +#include +#include + +#include "kai/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/float16.hpp" + +namespace kai::test { + +// NOLINTBEGIN(misc-non-private-member-variables-in-classes) + +/// Matrix multiplication method. +struct MatMulMethod { + std::string_view name; ///< Name of matmul method. + + size_t m0; ///< Block size in M dimension. + size_t n0; ///< Block size in N dimension. + + bool lhs_transposed; ///< LHS matrix is transposed. + bool rhs_transposed; ///< RHS matrix is transposed. + + bool is_sme2; ///< Test is a sme2 test + + DataFormat dst_format; ///< Data format of the destination matrix. 
+ DataFormat lhs_format; ///< Data format of the LHS matrix. + DataFormat packed_lhs_format; ///< Data format of the packed LHS matrix. + DataFormat rhs_format; ///< Data format of the RHS matrix. + DataFormat packed_rhs_format; ///< Data format of the packed RHS matrix. + DataFormat bias_format; ///< Data format of the bias vector. + + /// Gets mr value. + /// + /// This is the packing parameter which must be used to pack the LHS matrix (if necessary). + /// + /// @return The mr value. + std::function fn_get_mr; + + /// Gets nr value. + /// + /// This is the packing parameter which must be used to pack the RHS matrix (if necessary). + /// + /// @return The nr value. + std::function fn_get_nr; + + /// Gets kr value. + /// + /// This is the packing parameter which must be used to pack the LHS and RHS matrix (if necessary). + /// + /// @return The kr value. + std::function fn_get_kr; + + /// Gets sr value. + /// + /// This is the packing parameter which must be used to pack the RHS matrix. + /// + /// @return The sr value. + std::function fn_get_sr; + + /// Gets m step value for main kernel. + /// + /// The starting row index must be divisible by `m_step`. + /// + /// @return The m step value. + std::function fn_get_main_m_step; + + /// Gets n step value for RHS packing kernel. + /// + /// The starting row index must be divisible by `n_step`. + /// + /// @return The n step value. + std::function fn_get_pack_rhs_n_step; + + /// Gets n step value for main kernel. + /// + /// The starting column index must be divisible by `n_step`. + /// + /// @return The n step value. + std::function fn_get_main_n_step; + + /// Gets the offset in bytes of the LHS matrix. + /// + /// @param[in] m_idx Coordinate of the matrix in M dimension. + /// @param[in] stride Row stride in bytes. + /// + /// @return The offset in bytes. + std::function fn_get_lhs_offset; + + /// Gets the size in bytes of the packed LHS matrix. + /// + /// @param[in] m Number of rows in the unpacked LHS matrix. 
+ /// @param[in] k Number of columns in the unpacked LHS matrix. + /// @param[in] mr Number of rows to be interleaved. + /// @param[in] kr Unused. Must be 1. + /// @param[in] sr Unused. Must be 1. + /// + /// @return The size in bytes. + std::function fn_get_packed_lhs_size; + + /// Gets the offset in bytes of the packed LHS matrix. + /// + /// @param[in] m_idx Coordinate of the matrix in M dimension. + /// @param[in] k Size of the matrix in K dimension. + /// + /// @return The offset in bytes. + std::function fn_get_packed_lhs_offset; + + /// Preprocesses the LHS matrix. + /// + /// @param[in] m Number of rows of the unpacked LHS matrix. + /// @param[in] k Common dimension between the LHS and RHS matrix. + /// @param[in] mr Block size in M dimension. It must be {{ kernel.interleave_by }}VL. + /// @param[in] kr Block size in K dimension. It must be {{ kernel.block_by }}. + /// @param[in] sr Number of kr splits. It must be 1. + /// @param[in] m_idx_start Unused. Must be 0. + /// @param[in] lhs LHS matrix data buffer. + /// @param[in] lhs_stride Row stride in bytes of the LHS matrix. + /// @param[out] lhs_packed Packed RHS matrix. + std::function + fn_pack_lhs; + + /// Gets a value indicating whether LHS packing is needed. + [[nodiscard]] bool is_pack_lhs_needed() const { + return fn_pack_lhs != nullptr; + } + + /// Gets the offset in bytes of the RHS matrix. + /// + /// @param[in] n_idx Coordinate of the matrix in N dimension. + /// + /// @return The offset in bytes. + std::function fn_get_rhs_offset; + + /// Gets the size in bytes of the packed RHS matrix. + /// + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// + /// @return The size in bytes. + std::function fn_get_packed_rhs_size; + + /// Gets the offset in bytes of the packed RHS matrix in the RHS packing kernel + /// + /// @param[in] n_idx Coordinate of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. 
+ /// + /// @return The offset in bytes. + std::function fn_get_pack_rhs_packed_rhs_offset; + + /// Gets the offset in bytes of the packed RHS matrix in the main kernel. + /// + /// @param[in] n_idx Coordinate of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// + /// @return The offset in bytes. + std::function fn_get_main_packed_rhs_offset; + + std::function + fn_pack_rhs; + + /// Gets the offset in bytes to the data element in the bias buffer. + /// + /// @param[in] n_idx Column index. + /// + /// @return The offset in bytes to the data element. + std::function fn_get_bias_offset; + + /// Gets the offset in bytes to the data element in the destination matrix buffer. + /// + /// @param[in] m_idx Row index. + /// @param[in] n_idx Column index. + /// @param[in] stride Row stride in bytes. + /// + /// @return The offset in bytes to the data element. + std::function fn_get_dst_offset; + + /// Gets the size in bytes of the destination matrix buffer. + /// + /// @param[in] m Number of rows. + /// @param[in] n Number of columns. + /// + /// @return The size in bytes of the destination matrix buffer. + std::function fn_get_dst_size; + + /// Performs F16 or F32 matrix multiplication with RHS packing + /// followed by clamp operation. + /// + /// @param[in] m Size of the matrix in M dimension. + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] lhs LHS data buffer. + /// @param[in] packed_rhs Packed RHS data buffer. + /// @param[out] dst Output data buffer. + /// @param[in] lhs_stride LHS row stride. + /// @param[in] dst_stride_row Output row stride. + /// @param[in] dst_stride_col Output column stride. + /// @param[in] clamp_min Lower bound of the output data. + /// @param[in] clamp_max Upper bound of the output data. 
+ std::function + fn_matmul_f16_f16_f16p; + + std::function + fn_matmul_f32_f32_f32p; + + /// Performs BF16 matrix multiplication with LHS and RHS packing + /// followed by clamp operation. + /// + /// @param[in] m Size of the matrix in M dimension. + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] packed_lhs Packed LHS data buffer. + /// @param[in] packed_rhs Packed RHS data buffer. + /// @param[out] dst Output data buffer. + /// @param[in] dst_stride_row Output row stride. + /// @param[in] dst_stride_col Output column stride. + /// @param[in] clamp_min Lower bound of the output data. + /// @param[in] clamp_max Upper bound of the output data. + std::function + fn_matmul_f32_bf16p_bf16p; + + /// Performs F32 matrix multiplication with LHS & RHS packing + /// followed by clamp operation. + /// + /// @param[in] m Number of output rows to be computed. + /// @param[in] n Number of output columns to be computed. + /// @param[in] k Common dimension of the LHS and RHS operands. + /// @param[in] packed_lhs Packed LHS matrix buffer. + /// @param[in] packed_rhs Packed RHS matrix buffer. + /// @param[out] dst Output matrix buffer. + /// @param[in] dst_stride_row Row stride in bytes of the output matrix. + /// @param[in] dst_stride_col Column stride in bytes of the output matrix. + /// @param[in] clamp_min Minimum value to clamp the final result. + /// @param[in] clamp_max Maximum value to clamp the final result. + std::function + fn_matmul_f32_f32p_f32p; + + /// Gets a value indicating whether pre-processing the RHS matrix is needed. + [[nodiscard]] bool is_pack_rhs_needed() const { + return fn_pack_rhs != nullptr; + } + + /// Preprocesses the RHS matrix. + /// + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] rhs RHS data buffer. + /// @param[in] rhs_row_stride RHS row stride. + /// @param[in] bias Bias data buffer. 
+ /// @param[in] scale Quantization scales data buffer. + /// @param[out] packed_rhs Packed RHS data buffer. + void pack_rhs( + size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale, + void* packed_rhs) const { + KAI_UNUSED(n); + KAI_UNUSED(k); + KAI_UNUSED(rhs); + KAI_UNUSED(rhs_row_stride); + KAI_UNUSED(bias); + KAI_UNUSED(scale); + KAI_UNUSED(packed_rhs); + + if (fn_pack_rhs != nullptr) { + fn_pack_rhs( + 1, n, k, fn_get_nr(), fn_get_kr(), fn_get_sr(), rhs_row_stride, rhs, bias, nullptr, packed_rhs, 0, + nullptr); + } else { + KAI_ERROR("RHS pre-processing is not supported!"); + } + } + + [[nodiscard]] bool has_main_kernel() const { + return fn_matmul_f16_f16_f16p != nullptr || fn_matmul_f32_f32p_f32p != nullptr || + fn_matmul_f32_f32_f32p != nullptr || fn_matmul_f32_bf16p_bf16p != nullptr; + } + + void main_kernel( + size_t m, size_t n, size_t k, const void* lhs, const void* rhs, const void* bias, void* dst, size_t lhs_stride, + size_t rhs_stride, size_t dst_stride, float clamp_min, float clamp_max) const { + KAI_UNUSED(bias); + KAI_UNUSED(rhs_stride); + + if (fn_matmul_f16_f16_f16p) { + fn_matmul_f16_f16_f16p( + m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(Float16), clamp_min, + static_cast(clamp_max)); + } else if (fn_matmul_f32_f32_f32p) { + fn_matmul_f32_f32_f32p( + m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(float), clamp_min, + static_cast(clamp_max)); + } else if (fn_matmul_f32_f32p_f32p) { + fn_matmul_f32_f32p_f32p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); + } else if (fn_matmul_f32_bf16p_bf16p) { + fn_matmul_f32_bf16p_bf16p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); + } else { + KAI_ERROR("Main kernel is not available!"); + } + } +}; + +// NOLINTEND(misc-non-private-member-variables-in-classes) +} // namespace kai::test diff --git a/test/common/bfloat16.hpp b/test/common/bfloat16.hpp index ca0e0b37..9291918a 100644 --- 
a/test/common/bfloat16.hpp +++ b/test/common/bfloat16.hpp @@ -82,6 +82,14 @@ public: return _data != rhs._data; } + uint16_t data() const { + return _data; + } + + void set_data(uint16_t data) { + _data = data; + } + /// Writes the value to the output stream. /// /// @param[in] os Output stream to be written to. diff --git a/test/common/compare.cpp b/test/common/compare.cpp index 54af776f..b000f3f5 100644 --- a/test/common/compare.cpp +++ b/test/common/compare.cpp @@ -213,6 +213,9 @@ bool compare( case DataType::FP16: return compare_raw(imp_data, ref_data, format, full_height, full_width, rect, handler); + case DataType::BF16: + return compare_raw(imp_data, ref_data, format, full_height, full_width, rect, handler); + default: break; } diff --git a/test/common/matmul_test_common.cpp b/test/common/matmul_test_common.cpp new file mode 100644 index 00000000..905450fb --- /dev/null +++ b/test/common/matmul_test_common.cpp @@ -0,0 +1,25 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "matmul_test_common.hpp" + +#include +#include + +namespace kai::test { +void PrintTo(const MatMulTestParams& param, std::ostream* os) { + const auto& [method, shape, portion] = param; + + // NOLINTBEGIN(cppcoreguidelines-pro-bounds-constant-array-index) + *os << "Method_" << method.name // + << "__M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k // + << "__PortionStartRow_" << static_cast(portion.start_row() * 1000) // + << "__PortionStartCol_" << static_cast(portion.start_col() * 1000) // + << "__PortionHeight_" << static_cast(portion.height() * 1000) // + << "__PortionWidth_" << static_cast(portion.width() * 1000); + // NOLINTEND(cppcoreguidelines-pro-bounds-constant-array-index) +} +} // namespace kai::test diff --git a/test/common/matmul_test_common.hpp b/test/common/matmul_test_common.hpp new file mode 100644 index 00000000..767117c5 --- /dev/null +++ 
b/test/common/matmul_test_common.hpp @@ -0,0 +1,26 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "test/common/MatMulMethod.hpp" +#include "test/common/matrix_portion.hpp" + +namespace kai::test { +/// Matrix multiplication shape. +struct MatMulShape { + size_t m; ///< LHS height. + size_t n; ///< RHS width. + size_t k; ///< LHS width and RHS height. +}; + +/// Matrix multiplication test information. +using MatMulTestParams = std::tuple; + +/// Prints the test information. +void PrintTo(const MatMulTestParams& param, std::ostream* os); +} // namespace kai::test diff --git a/test/common/memory.hpp b/test/common/memory.hpp index bf5fbb01..7ea0aa53 100644 --- a/test/common/memory.hpp +++ b/test/common/memory.hpp @@ -9,6 +9,7 @@ #include #include +#include "test/common/bfloat16.hpp" #include "test/common/int4.hpp" namespace kai::test { @@ -25,6 +26,14 @@ inline constexpr size_t size_in_bits = 4; template <> inline constexpr size_t size_in_bits = 4; +/// TODO: Move this +inline float bf16_to_float(uint16_t v) { + const uint32_t lv = (v << 16); + float fp; + memcpy(&fp, &lv, sizeof(lv)); + return fp; +} + /// Reads the array at the specified index. /// /// @param[in] array Data buffer. @@ -39,6 +48,9 @@ T read_array(const void* array, size_t index) { } else if constexpr (std::is_same_v) { const auto [lo, hi] = Int4::unpack_u8(reinterpret_cast(array)[index / 2]); return index % 2 == 0 ? 
lo : hi; + } else if constexpr (std::is_same_v) { + uint16_t raw_value = reinterpret_cast(array)[index]; + return BFloat16(bf16_to_float(raw_value)); } else { return reinterpret_cast(array)[index]; } diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp index 221ba360..5c549664 100644 --- a/test/reference/pack.cpp +++ b/test/reference/pack.cpp @@ -14,6 +14,7 @@ #include #include "kai/kai_common.h" +#include "test/common/bfloat16.hpp" #include "test/common/data_format.hpp" #include "test/common/data_type.hpp" #include "test/common/float16.hpp" @@ -25,14 +26,20 @@ namespace kai::test { namespace { +uint16_t convert(const uint8_t* src_ptr_elm, DataType src_dtype, DataType dst_dtype) { + KAI_ASSUME(src_dtype == DataType::FP32 && dst_dtype == DataType::BF16); + return BFloat16(*reinterpret_cast(src_ptr_elm)).data(); +} + std::vector pack_block( - const void* src, size_t data_esize, size_t full_height, size_t full_width, size_t block_height, size_t block_width, - size_t subblock_height, size_t subblock_width) { + const void* src, DataType src_dtype, DataType dst_dtype, size_t src_esize, size_t dst_esize, size_t full_height, + size_t full_width, size_t block_height, size_t block_width, size_t subblock_height, size_t subblock_width) { const auto dst_bytes = - round_up_multiple(full_height, block_height) * round_up_multiple(full_width, block_width) * data_esize; + round_up_multiple(full_height, block_height) * round_up_multiple(full_width, block_width) * dst_esize; std::vector dst; dst.resize(dst_bytes); + memset(dst.data(), 0, dst_bytes); const auto* src_ptr = reinterpret_cast(src); auto* dst_ptr = dst.data(); @@ -42,18 +49,38 @@ std::vector pack_block( for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) { for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) { for (size_t y_element = 0; y_element < subblock_height; ++y_element) { - if (y_block + y_subblock + y_element < full_height) { - const auto 
len = std::min(subblock_width, full_width - x_block - x_subblock); - - memcpy( - dst_ptr, - src_ptr + - ((y_block + y_subblock + y_element) * full_width + x_block + x_subblock) * - data_esize, - len * data_esize); - } + if (src_dtype == dst_dtype) { + const size_t esize = dst_esize; + + if (y_block + y_subblock + y_element < full_height) { + const auto len = std::min(subblock_width, full_width - x_block - x_subblock); + + memcpy( + dst_ptr, + src_ptr + + ((y_block + y_subblock + y_element) * full_width + x_block + x_subblock) * + esize, + len * esize); + } - dst_ptr += subblock_width * data_esize; + dst_ptr += subblock_width * esize; + } else if (dst_esize == 2 /* 16 bits */) { + for (size_t x_element = 0; x_element < subblock_width; ++x_element) { + if (y_block + y_subblock + y_element < full_height) { + if (x_block + x_subblock + x_element < full_width) { + const uint8_t* src_ptr_elm = src_ptr + + ((y_block + y_subblock + y_element) * full_width + x_block + x_subblock + + x_element) * + src_esize; + + uint16_t src_value = convert(src_ptr_elm, src_dtype, dst_dtype); + memcpy(dst_ptr, &src_value, dst_esize); + } + } + + dst_ptr += dst_esize; + } + } } } } @@ -67,43 +94,67 @@ std::vector pack_block( /// Packs the matrix from raw to per-row bias format. 
std::vector pack_bias_per_row( - size_t data_esize, size_t zero_point_esize, const void* src, const void* bias, size_t height, size_t width, - size_t block_height, size_t block_width, size_t subblock_height, size_t subblock_width) { + DataType src_dtype, DataType bias_dtype, DataType dst_dtype, size_t src_esize, size_t bias_esize, size_t dst_esize, + const void* src, const void* bias, size_t height, size_t width, size_t block_height, size_t block_width, + size_t subblock_height, size_t subblock_width) { + KAI_ASSUME(src_dtype == bias_dtype); + const auto num_groups = (height + block_height - 1) / block_height; const auto group_num_blocks = (width + block_width - 1) / block_width; - - const auto group_zero_points_bytes = block_height * zero_point_esize; - const auto block_data_bytes = block_height * block_width * data_esize; - const auto group_bytes = group_zero_points_bytes + group_num_blocks * block_data_bytes; + const auto group_bias_bytes = block_height * bias_esize; + const auto block_data_bytes = block_height * block_width * dst_esize; + const auto group_bytes = group_bias_bytes + group_num_blocks * block_data_bytes; const auto dst_bytes = num_groups * group_bytes; std::vector dst; dst.resize(dst_bytes); + memset(dst.data(), 0, dst_bytes); const auto* src_ptr = reinterpret_cast(src); const auto* bias_ptr = reinterpret_cast(bias); auto* dst_ptr = dst.data(); for (size_t y_block = 0; y_block < height; y_block += block_height) { - // Packs the zero points. + // Packs the bias. 
const auto bias_len = std::min(block_height, height - y_block); - memcpy(dst_ptr, bias_ptr, bias_len * zero_point_esize); - bias_ptr += block_height * zero_point_esize; - dst_ptr += block_height * zero_point_esize; + memcpy(dst_ptr, bias_ptr, bias_len * bias_esize); + bias_ptr += block_height * bias_esize; + dst_ptr += block_height * bias_esize; for (size_t x_block = 0; x_block < width; x_block += block_width) { for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) { for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) { for (size_t y_element = 0; y_element < subblock_height; ++y_element) { - if (y_block + y_subblock + y_element < height) { - const auto len = std::min(subblock_width, width - x_block - x_subblock); - memcpy( - dst_ptr, - src_ptr + - ((y_block + y_subblock + y_element) * width + x_block + x_subblock) * data_esize, - len * data_esize); + if (src_dtype == dst_dtype) { + const size_t esize = dst_esize; + if (y_block + y_subblock + y_element < height) { + const auto len = std::min(subblock_width, width - x_block - x_subblock); + + memcpy( + dst_ptr, + src_ptr + + ((y_block + y_subblock + y_element) * width + x_block + x_subblock) * esize, + len * esize); + } + + dst_ptr += subblock_width * esize; + } else if (dst_esize == 2 /* 16 bits */) { + for (size_t x_element = 0; x_element < subblock_width; ++x_element) { + if (y_block + y_subblock + y_element < height) { + if (x_block + x_subblock + x_element < width) { + const uint8_t* src_ptr_elm = src_ptr + + ((y_block + y_subblock + y_element) * width + x_block + x_subblock + + x_element) * + src_esize; + + uint16_t src_value = convert(src_ptr_elm, src_dtype, dst_dtype); + memcpy(dst_ptr, &src_value, dst_esize); + } + } + + dst_ptr += dst_esize; + } } - dst_ptr += subblock_width * data_esize; } } } @@ -118,7 +169,7 @@ std::vector pack_bias_per_row( } // namespace std::vector pack( - const DataFormat& dst_format, const void* src, [[maybe_unused]] 
const void* scales, const void* zero_points, + const DataFormat& dst_format, const void* src, [[maybe_unused]] const void* scales, const void* bias, const DataFormat& src_format, size_t height, size_t width) { const auto dst_dt = dst_format.data_type(); const auto dst_qf = dst_format.pack_format(); @@ -131,27 +182,31 @@ std::vector pack( const auto subblock_width = dst_format.actual_subblock_width(width); if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::BIAS_PER_ROW) { - KAI_ASSUME(src_dt == dst_dt); + KAI_ASSUME((src_dt == dst_dt) || (src_dt == DataType::FP32 && dst_dt == DataType::BF16)); - const auto data_esize = data_type_size_in_bits(dst_dt); - const auto zero_point_esize = data_type_size_in_bits(dst_format.zero_point_data_type()); + const auto src_esize = data_type_size_in_bits(src_dt); + const auto dst_esize = data_type_size_in_bits(dst_dt); + const auto bias_esize = data_type_size_in_bits(dst_format.zero_point_data_type()); + const auto bias_dt = dst_format.zero_point_data_type(); - if (data_esize % 8 == 0 && zero_point_esize % 8 == 0) { - return pack_bias_per_row( - data_esize / 8, zero_point_esize / 8, src, zero_points, height, width, block_height, block_width, - subblock_height, subblock_width); - } + KAI_ASSUME(dst_esize % 8 == 0 && bias_esize % 8 == 0 && src_esize % 8 == 0); + + return pack_bias_per_row( + src_dt, bias_dt, dst_dt, src_esize / 8, bias_esize / 8, dst_esize / 8, src, bias, height, width, + block_height, block_width, subblock_height, subblock_width); } if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::NONE) { - KAI_ASSUME(src_dt == dst_dt); + KAI_ASSUME((src_dt == dst_dt) || (src_dt == DataType::FP32 && dst_dt == DataType::BF16)); - const auto data_esize = data_type_size_in_bits(dst_dt); + const auto dst_esize = data_type_size_in_bits(dst_dt); + const auto src_esize = data_type_size_in_bits(src_dt); - if (data_esize % 8 == 0) { - return pack_block( - src, data_esize / 8, height, 
width, block_height, block_width, subblock_height, subblock_width); - } + KAI_ASSUME(src_esize % 8 == 0 && dst_esize % 8 == 0); + + return pack_block( + src, src_dt, dst_dt, src_esize / 8, dst_esize / 8, height, width, block_height, block_width, + subblock_height, subblock_width); } KAI_ERROR("Unsupported operation!"); diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp new file mode 100644 index 00000000..493af592 --- /dev/null +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -0,0 +1,328 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "kai/kai_common.h" +#include "test/common/MatMulMethod.hpp" +#include "test/common/compare.hpp" +#include "test/common/cpu_info.hpp" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/common/float16.hpp" +#include "test/common/matmul_test_common.hpp" +#include "test/common/matrix_portion.hpp" +#include "test/common/printer.hpp" +#include "test/common/sme.hpp" +#include "test/reference/fill.hpp" +#include "test/reference/matmul.hpp" +#include "test/reference/pack.hpp" + +// matmul_clamp_f32_bf16p_bf16p +#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h" +#include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h" +namespace kai::test { + +/// List of supported matrix multiplication methods. 
+static const std::array matmul_methods = { + MatMulMethod{ + .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla", + + .m0 = 8, + .n0 = 12, + + .lhs_transposed = false, + .rhs_transposed = false, + + .is_sme2 = false, + + .dst_format = DataFormat(DataType::FP32), + .lhs_format = DataFormat(DataType::FP32), + .packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4), + .rhs_format = DataFormat(DataType::FP32), + .packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4), + .bias_format = DataFormat(DataType::FP32), + + .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + + .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p8x4_bf16_neon, + .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p8x4_bf16_neon, + .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_pack_f32p8x4_bf16_neon, + .fn_pack_lhs = kai_run_lhs_pack_f32p8x4_bf16_neon, + + .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + 
.fn_pack_rhs = kai_run_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + + .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + + .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + }, +}; + +/// Matrix multiplication test fixture. +class MatMulTestBf16 : public testing::TestWithParam { +private: + /// Unique ID: m, n, k + using TestDataId = std::tuple; + +protected: + /// Cached test data that is shared between multiple test case. + struct TestData { + std::vector lhs{}; ///< LHS operand. + std::vector ref_packed_lhs{}; ///< Reference packed LHS. + std::vector rhs{}; ///< RHS operand. + std::vector rhs_scales{}; ///< RHS per-row quantization scales. + std::vector bias{}; ///< Bias. + std::vector ref_packed_rhs{}; ///< Reference packed RHS. + std::vector ref_dst{}; ///< Reference output. + }; + + /// Gets the test data for the current test case. + static const TestData& test_data() { + const auto& [method, info, portion] = GetParam(); + const TestDataId data_id{info.m, info.n, info.k, method.name}; + + // If the test data is already available, returns it. + const auto data_it = _data.find(data_id); + + if (data_it != _data.end()) { + return data_it->second; + } + + // Generates the test data. + const auto has_lhs_pack = method.packed_lhs_format.data_type() != DataType::UNKNOWN; + const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN; + const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN; + + const auto lhs_h = method.lhs_transposed ? info.k : info.m; + const auto lhs_w = method.lhs_transposed ? 
info.m : info.k; + auto lhs = fill_matrix_random(lhs_h, lhs_w, method.lhs_format, 0); + std::vector ref_packed_lhs; + + if (has_lhs_pack) { + ref_packed_lhs = + pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w); + } + + const auto rhs_h = method.rhs_transposed ? info.n : info.k; + const auto rhs_w = method.rhs_transposed ? info.k : info.n; + auto rhs = fill_matrix_random(rhs_h, rhs_w, method.rhs_format, 1); + + std::vector rhs_scales; + if (data_type_is_quantized(method.rhs_format.data_type()) && + method.rhs_format.pack_format() == DataFormat::PackFormat::NONE) { + rhs_scales = fill_matrix_random(rhs_h, 1, DataFormat(DataType::FP32), 2); + } + + const auto bias_h = 1; + const auto bias_w = info.n; + std::vector bias; + + if (has_bias) { + bias = fill_matrix_random(bias_h, bias_w, method.bias_format, 3); + } + + std::vector packed_rhs; + if (has_rhs_pack) { + packed_rhs = matmul_pack_rhs( + rhs.data(), !rhs_scales.empty() ? rhs_scales.data() : nullptr, bias.data(), method.rhs_format, + method.packed_rhs_format, info.n, info.k, !method.rhs_transposed); + } + + KAI_ASSUME(method.lhs_format.is_raw()); + KAI_ASSUME(method.rhs_format.is_raw()); + KAI_ASSUME(method.dst_format.is_raw()); + + auto ref_dst = matmul( + lhs.data(), nullptr, nullptr, method.lhs_format.data_type(), // + rhs.data(), rhs_scales.data(), nullptr, method.rhs_format.data_type(), // + bias.data(), nullptr, nullptr, method.bias_format.data_type(), // + method.dst_format.data_type(), // + info.m, info.n, info.k, method.lhs_transposed, method.rhs_transposed); + + const auto& data = _data[data_id] = { + .lhs = std::move(lhs), + .ref_packed_lhs = std::move(ref_packed_lhs), + .rhs = std::move(rhs), + .rhs_scales = std::move(rhs_scales), + .bias = std::move(bias), + .ref_packed_rhs = std::move(packed_rhs), + .ref_dst = std::move(ref_dst), + }; + + return data; + } + +private: + // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) + static std::map 
_data; + // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) +}; + +// NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) +std::map MatMulTestBf16::_data; +// NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) + +/// Tests the output. +TEST_P(MatMulTestBf16, Output) { + const auto& [method, info, portion] = GetParam(); + const auto& data = test_data(); + + if (method.is_sme2 && !cpu_has_sme2()) { + GTEST_SKIP(); + } + + if (!method.has_main_kernel()) { + GTEST_SKIP(); + } + + const auto m_step = method.fn_get_main_m_step(); + ASSERT_EQ(m_step, method.m0); + + const auto n_step = method.fn_get_main_n_step(); + ASSERT_EQ(n_step, method.n0); + + const auto rect = portion.compute_portion(info.m, info.n, method.m0, method.n0); + + if (rect.height() == 0 || rect.width() == 0) { + GTEST_SKIP(); + } + + const auto lhs_w = method.lhs_transposed ? info.m : info.k; + const auto rhs_w = method.rhs_transposed ? info.k : info.n; + const auto bias_w = info.n; + const auto dst_w = info.n; + + const auto lhs_start_row = method.lhs_transposed ? 0 : rect.start_row(); + const auto lhs_start_col = method.lhs_transposed ? 
rect.start_row() : 0; + const auto lhs_stride = method.lhs_format.default_row_stride(lhs_w); + + const uint8_t* lhs_data = nullptr; + uintptr_t lhs_offset = 0; + + if (method.is_pack_lhs_needed()) { + lhs_data = data.ref_packed_lhs.data(); + + const auto ref_lhs_offset = + method.packed_lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, info.k); + KAI_UNUSED(ref_lhs_offset); + + lhs_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k); + + // TODO: Check with ref_lhs_offset after fixing default_offset_in_bytes() + } else { + lhs_data = data.lhs.data(); + + lhs_offset = method.fn_get_lhs_offset(lhs_start_row, lhs_stride); + const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, lhs_w); + ASSERT_EQ(lhs_offset, ref_lhs_offset); + } + + const auto rhs_stride = method.rhs_format.default_row_stride(rhs_w); + + const uint8_t* rhs_data = nullptr; + uintptr_t rhs_offset = 0; + + if (method.is_pack_rhs_needed()) { + const auto packed_rhs_start_row = rect.start_col(); + const auto packed_rhs_start_col = 0; + + rhs_data = data.ref_packed_rhs.data(); + + rhs_offset = method.fn_get_main_packed_rhs_offset(packed_rhs_start_row, info.k); + const auto ref_rhs_offset = + method.packed_rhs_format.default_offset_in_bytes(packed_rhs_start_row, packed_rhs_start_col, info.k); + + ASSERT_EQ(rhs_offset, ref_rhs_offset); + } else { + const auto rhs_start_row = method.rhs_transposed ? rect.start_col() : 0; + const auto rhs_start_col = method.rhs_transposed ? 
0 : rect.start_col(); + + rhs_data = data.rhs.data(); + rhs_offset = method.rhs_format.default_offset_in_bytes(rhs_start_row, rhs_start_col, rhs_w); + } + + const auto* bias_data = data.bias.data(); + const auto bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_row(), bias_w); + + const auto dst_stride = method.dst_format.default_row_stride(dst_w); + const auto dst_offset = method.fn_get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); + const auto ref_dst_offset = method.dst_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), dst_w); + ASSERT_EQ(dst_offset, ref_dst_offset); + + const auto dst_size = method.fn_get_dst_size(info.m, info.n); + const auto ref_dst_size = method.dst_format.default_size_in_bytes(info.m, info.n); + ASSERT_EQ(dst_size, ref_dst_size); + + std::vector dst; + dst.resize(dst_size); + + method.main_kernel( + rect.height(), rect.width(), info.k, lhs_data + lhs_offset, rhs_data + rhs_offset, bias_data + bias_offset, + dst.data() + dst_offset, lhs_stride, rhs_stride, dst_stride, -std::numeric_limits::infinity(), + std::numeric_limits::infinity()); + + DefaultMismatchHandler handler(0, 0.1, 0, 0.05); + const auto success = compare(dst.data(), data.ref_dst.data(), method.dst_format, info.m, info.n, rect, handler); + + ASSERT_TRUE(success); +} + +INSTANTIATE_TEST_SUITE_P( + MatMul, MatMulTestBf16, + testing::Combine( + testing::ValuesIn(matmul_methods), + testing::Values( + MatMulShape{3, 7, 3}, // Smaller than block size + MatMulShape{12, 8, 4}, // Same block size + MatMulShape{1, 1, 1023}, // Long K + MatMulShape{1013, 1, 5}, // Long M + MatMulShape{2, 1013, 6}, // Long N + MatMulShape{13, 33, 23}, // + MatMulShape{93, 57, 89}, // + MatMulShape{256, 256, 256}, // Nice shapes + MatMulShape{257, 113, 373} // Prime numbers + ), + testing::Values( + MatrixPortion(0, 0, 1, 1), // Full matrix. + MatrixPortion(0, 0, 0.25, 0.25), // Top-left corner. + MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. 
+ MatrixPortion(0.75, 0, 1, 1), // Partial rows + MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle + )), + testing::PrintToStringParamName()); + +} // namespace kai::test diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index f2d5e544..90f1f1df 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -21,11 +21,13 @@ #include #include "kai/kai_common.h" +#include "test/common/MatMulMethod.hpp" #include "test/common/compare.hpp" #include "test/common/cpu_info.hpp" #include "test/common/data_format.hpp" #include "test/common/data_type.hpp" #include "test/common/float16.hpp" +#include "test/common/matmul_test_common.hpp" #include "test/common/matrix_portion.hpp" #include "test/common/printer.hpp" #include "test/common/sme.hpp" @@ -46,296 +48,6 @@ #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h" namespace kai::test { -// NOLINTBEGIN(misc-non-private-member-variables-in-classes) - -/// Matrix multiplication method. -struct MatMulMethod { - std::string_view name; ///< Name of matmul method. - - size_t m0; ///< Block size in M dimension. - size_t n0; ///< Block size in N dimension. - - bool lhs_transposed; ///< LHS matrix is transposed. - bool rhs_transposed; ///< RHS matrix is transposed. - - DataFormat dst_format; ///< Data format of the destination matrix. - DataFormat lhs_format; ///< Data format of the LHS matrix. - DataFormat packed_lhs_format; ///< Data format of the packed LHS matrix. - DataFormat rhs_format; ///< Data format of the RHS matrix. - DataFormat packed_rhs_format; ///< Data format of the packed RHS matrix. - DataFormat bias_format; ///< Data format of the bias vector. - - /// Check if CPU supports required features. - /// - /// @return Supported (true) or not supported (false). - std::function fn_is_supported; - - /// Gets mr value. - /// - /// This is the packing parameter which must be used to pack the LHS matrix (if necessary). - /// - /// @return The mr value. 
- std::function fn_get_mr; - - /// Gets nr value. - /// - /// This is the packing parameter which must be used to pack the RHS matrix (if necessary). - /// - /// @return The nr value. - std::function fn_get_nr; - - /// Gets kr value. - /// - /// This is the packing parameter which must be used to pack the LHS and RHS matrix (if necessary). - /// - /// @return The kr value. - std::function fn_get_kr; - - /// Gets sr value. - /// - /// This is the packing parameter which must be used to pack the RHS matrix. - /// - /// @return The sr value. - std::function fn_get_sr; - - /// Gets m step value for main kernel. - /// - /// The starting row index must be divisible by `m_step`. - /// - /// @return The m step value. - std::function fn_get_main_m_step; - - /// Gets n step value for RHS packing kernel. - /// - /// The starting row index must be divisible by `n_step`. - /// - /// @return The n step value. - std::function fn_get_pack_rhs_n_step; - - /// Gets n step value for main kernel. - /// - /// The starting column index must be divisible by `n_step`. - /// - /// @return The n step value. - std::function fn_get_main_n_step; - - /// Gets the offset in bytes of the LHS matrix. - /// - /// @param[in] m_idx Coordinate of the matrix in M dimension. - /// @param[in] stride Row stride in bytes. - /// - /// @return The offset in bytes. - std::function fn_get_lhs_offset; - - /// Gets the size in bytes of the packed LHS matrix. - /// - /// @param[in] m Number of rows in the unpacked LHS matrix. - /// @param[in] k Number of columns in the unpacked LHS matrix. - /// @param[in] mr Number of rows to be interleaved. - /// @param[in] kr Unused. Must be 1. - /// @param[in] sr Unused. Must be 1. - /// - /// @return The size in bytes. - std::function fn_get_packed_lhs_size; - - /// Gets the offset in bytes of the packed LHS matrix. - /// - /// @param[in] m_idx Coordinate of the matrix in M dimension. - /// @param[in] k Size of the matrix in K dimension. 
- /// - /// @return The offset in bytes. - std::function fn_get_packed_lhs_offset; - - /// Preprocesses the LHS matrix. - /// - /// @param[in] m Number of rows of the unpacked LHS matrix. - /// @param[in] k Common dimension between the LHS and RHS matrix. - /// @param[in] mr Block size in M dimension. It must be {{ kernel.interleave_by }}VL. - /// @param[in] kr Block size in K dimension. It must be {{ kernel.block_by }}. - /// @param[in] sr Number of kr splits. It must be 1. - /// @param[in] m_idx_start Unused. Must be 0. - /// @param[in] lhs LHS matrix data buffer. - /// @param[in] lhs_stride Row stride in bytes of the LHS matrix. - /// @param[out] lhs_packed Packed RHS matrix. - std::function - fn_pack_lhs; - - /// Gets a value indicating whether LHS packing is needed. - [[nodiscard]] bool is_pack_lhs_needed() const { - return fn_pack_lhs != nullptr; - } - - /// Gets the offset in bytes of the RHS matrix. - /// - /// @param[in] n_idx Coordinate of the matrix in N dimension. - /// - /// @return The offset in bytes. - std::function fn_get_rhs_offset; - - /// Gets the size in bytes of the packed RHS matrix. - /// - /// @param[in] n Size of the matrix in N dimension. - /// @param[in] k Size of the matrix in K dimension. - /// - /// @return The size in bytes. - std::function fn_get_packed_rhs_size; - - /// Gets the offset in bytes of the packed RHS matrix in the RHS packing kernel - /// - /// @param[in] n_idx Coordinate of the matrix in N dimension. - /// @param[in] k Size of the matrix in K dimension. - /// - /// @return The offset in bytes. - std::function fn_get_pack_rhs_packed_rhs_offset; - - /// Gets the offset in bytes of the packed RHS matrix in the main kernel. - /// - /// @param[in] n_idx Coordinate of the matrix in N dimension. - /// @param[in] k Size of the matrix in K dimension. - /// - /// @return The offset in bytes. 
- std::function fn_get_main_packed_rhs_offset; - - std::function - fn_pack_rhs; - - /// Gets the offset in bytes to the data element in the bias buffer. - /// - /// @param[in] n_idx Column index. - /// - /// @return The offset in bytes to the data element. - std::function fn_get_bias_offset; - - /// Gets the offset in bytes to the data element in the destination matrix buffer. - /// - /// @param[in] m_idx Row index. - /// @param[in] n_idx Column index. - /// @param[in] stride Row stride in bytes. - /// - /// @return The offset in bytes to the data element. - std::function fn_get_dst_offset; - - /// Gets the size in bytes of the destination matrix buffer. - /// - /// @param[in] m Number of rows. - /// @param[in] n Number of columns. - /// - /// @return The size in bytes of the destination matrix buffer. - std::function fn_get_dst_size; - - /// Performs F16 or F32 matrix multiplication with RHS packing - /// followed by clamp operation. - /// - /// @param[in] m Size of the matrix in M dimension. - /// @param[in] n Size of the matrix in N dimension. - /// @param[in] k Size of the matrix in K dimension. - /// @param[in] lhs LHS data buffer. - /// @param[in] packed_rhs Packed RHS data buffer. - /// @param[out] dst Output data buffer. - /// @param[in] lhs_stride LHS row stride. - /// @param[in] dst_stride Output row stride. - /// @param[in] clamp_min Lower bound of the output data. - /// @param[in] clamp_max Upper bound of the output data. - std::function - fn_matmul_f16_f16_f16p; - - std::function - fn_matmul_f32_f32_f32p; - - /// Performs F32 matrix multiplication with LHS & RHS packing - /// followed by clamp operation. - /// - /// @param[in] m Number of output rows to be computed. - /// @param[in] n Number of output columns to be computed. - /// @param[in] k Common dimension of the LHS and RHS operands. - /// @param[in] packed_lhs Packed LHS matrix buffer. - /// @param[in] packed_rhs Packed RHS matrix buffer. - /// @param[out] dst Output matrix buffer. 
- /// @param[in] dst_stride_row Row stride in bytes of the output matrix. - /// @param[in] dst_stride_col Column stride in bytes of the output matrix. - /// @param[in] clamp_min Minimum value to clamp the final result. - /// @param[in] clamp_max Maximum value to clamp the final result. - std::function - fn_matmul_f32_f32p_f32p; - - /// Gets a value indicating whether pre-processing the RHS matrix is needed. - [[nodiscard]] bool is_pack_rhs_needed() const { - return fn_pack_rhs != nullptr; - } - - /// Preprocesses the RHS matrix. - /// - /// @param[in] n Size of the matrix in N dimension. - /// @param[in] k Size of the matrix in K dimension. - /// @param[in] rhs RHS data buffer. - /// @param[in] rhs_row_stride RHS row stride. - /// @param[in] bias Bias data buffer. - /// @param[in] scale Quantization scales data buffer. - /// @param[out] packed_rhs Packed RHS data buffer. - void pack_rhs( - size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale, - void* packed_rhs) const { - KAI_UNUSED(n); - KAI_UNUSED(k); - KAI_UNUSED(rhs); - KAI_UNUSED(rhs_row_stride); - KAI_UNUSED(bias); - KAI_UNUSED(scale); - KAI_UNUSED(packed_rhs); - - if (fn_pack_rhs != nullptr) { - fn_pack_rhs( - 1, n, k, fn_get_nr(), fn_get_kr(), fn_get_sr(), rhs_row_stride, rhs, bias, nullptr, packed_rhs, 0, - nullptr); - } else { - KAI_ERROR("RHS pre-processing is not supported!"); - } - } - - [[nodiscard]] bool has_main_kernel() const { - return fn_matmul_f16_f16_f16p != nullptr || fn_matmul_f32_f32p_f32p != nullptr || - fn_matmul_f32_f32_f32p != nullptr; - } - - void main_kernel( - size_t m, size_t n, size_t k, const void* lhs, const void* rhs, const void* bias, void* dst, size_t lhs_stride, - size_t rhs_stride, size_t dst_stride, float clamp_min, float clamp_max) const { - KAI_UNUSED(bias); - KAI_UNUSED(rhs_stride); - if (fn_matmul_f16_f16_f16p) { - fn_matmul_f16_f16_f16p( - m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(Float16), clamp_min, - 
static_cast(clamp_max)); - } else if (fn_matmul_f32_f32_f32p) { - fn_matmul_f32_f32_f32p( - m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(float), clamp_min, - static_cast(clamp_max)); - } else if (fn_matmul_f32_f32p_f32p) { - fn_matmul_f32_f32p_f32p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); - } else { - KAI_ERROR("Main kernel is not available!"); - } - } -}; - -// NOLINTEND(misc-non-private-member-variables-in-classes) - /// List of supported matrix multiplication methods. static const std::array matmul_methods = { MatMulMethod{ @@ -486,35 +198,11 @@ static const std::array matmul_methods = { }, }; -/// Matrix multiplication shape. -struct MatMulShape { - size_t m; ///< LHS height. - size_t n; ///< RHS width. - size_t k; ///< LHS width and RHS height. -}; - -/// Matrix multiplication test information. -using MatMulTestParams = std::tuple; - -/// Prints the test information. -void PrintTo(const MatMulTestParams& param, std::ostream* os) { - const auto& [method_no, shape, portion] = param; - - // NOLINTBEGIN(cppcoreguidelines-pro-bounds-constant-array-index) - *os << "Method_" << matmul_methods[method_no].name // - << "__M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k // - << "__PortionStartRow_" << static_cast(portion.start_row() * 1000) // - << "__PortionStartCol_" << static_cast(portion.start_col() * 1000) // - << "__PortionHeight_" << static_cast(portion.height() * 1000) // - << "__PortionWidth_" << static_cast(portion.width() * 1000); - // NOLINTEND(cppcoreguidelines-pro-bounds-constant-array-index) -} - /// Matrix multiplication test fixture. class MatMulTest : public testing::TestWithParam { private: /// Unique ID: m, n, k, method_id. - using TestDataId = std::tuple; + using TestDataId = std::tuple; protected: /// Cached test data that is shared between multiple test case. @@ -530,8 +218,8 @@ protected: /// Gets the test data for the current test case. 
static const TestData& test_data() { - const auto& [method_no, info, portion] = GetParam(); - const TestDataId data_id{info.m, info.n, info.k, method_no}; + const auto& [method, info, portion] = GetParam(); + const TestDataId data_id{info.m, info.n, info.k, method.name}; // If the test data is already available, returns it. const auto data_it = _data.find(data_id); @@ -541,8 +229,6 @@ protected: } // Generates the test data. - const auto& method = matmul_methods.at(method_no); - const auto has_lhs_pack = method.packed_lhs_format.data_type() != DataType::UNKNOWN; const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN; const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN; @@ -617,9 +303,8 @@ std::map MatMulTest::_data; /// Tests the LHS packing kernel. TEST_P(MatMulTest, PackedLhs) { - const auto& [method_no, info, portion] = GetParam(); + const auto& [method, info, portion] = GetParam(); const auto& data = test_data(); - const auto& method = matmul_methods.at(method_no); if (method.fn_is_supported && !method.fn_is_supported()) { GTEST_SKIP(); @@ -668,9 +353,8 @@ TEST_P(MatMulTest, PackedLhs) { /// Tests the RHS packing kernel. TEST_P(MatMulTest, PackedRhs) { - const auto& [method_no, info, portion] = GetParam(); + const auto& [method, info, portion] = GetParam(); const auto& data = test_data(); - const auto& method = matmul_methods.at(method_no); if (method.fn_is_supported && !method.fn_is_supported()) { GTEST_SKIP(); @@ -739,9 +423,8 @@ TEST_P(MatMulTest, PackedRhs) { /// Tests the output. 
TEST_P(MatMulTest, Output) { - const auto& [method_no, info, portion] = GetParam(); + const auto& [method, info, portion] = GetParam(); const auto& data = test_data(); - const auto& method = matmul_methods.at(method_no); if (method.fn_is_supported && !method.fn_is_supported()) { GTEST_SKIP(); @@ -837,7 +520,7 @@ TEST_P(MatMulTest, Output) { INSTANTIATE_TEST_SUITE_P( MatMul, MatMulTest, testing::Combine( - testing::Range(0, matmul_methods.size()), + testing::ValuesIn(matmul_methods), testing::Values( MatMulShape{1, 16, 16}, // MatMulShape{20, 1, 20}, // -- GitLab From 5869f8dfb3136042a68448b75e48e52adb9a2359 Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Mon, 30 Sep 2024 17:30:43 +0300 Subject: [PATCH 03/10] Add optional bias support Signed-off-by: Gunes Bayir --- ...s_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c | 20 +++++-- test/common/MatMulMethod.hpp | 1 - .../matmul_clamp_f32_bf16p_bf16p_test.cpp | 56 +++++++++++++++++-- 3 files changed, 68 insertions(+), 9 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c index 943067e5..9e7c7dc9 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c @@ -51,7 +51,6 @@ void kai_run_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon( KAI_ASSUME(kr == kai_kr); KAI_ASSUME(sr == 1); KAI_ASSUME(rhs != NULL); - KAI_ASSUME(bias != NULL); KAI_ASSUME(scale == NULL); KAI_ASSUME(rhs_packed != NULL); KAI_ASSUME(extra_bytes == 0); @@ -68,6 +67,18 @@ void kai_run_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon( memset(pad_row, 0, width * sizeof(float)); } + // Fill zeros if bias is nullptr + size_t bias_step = nr * sizeof(float); + + void* zero_bias = NULL; + if (bias == NULL) { + zero_bias = alloca(bias_step); + memset(zero_bias, 0, bias_step); + bias_step = 0; + } + + const void* bias_ptr = bias == NULL ? 
zero_bias : bias; + size_t out_stride = kai_nr * kai_roundup(height, 4) * sizeof(bfloat16_t) + kai_nr * sizeof(uint32_t); __asm__ __volatile__( @@ -81,7 +92,7 @@ void kai_run_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon( "sub x22, x22, #0xc\n" "ldr q8, [%x[bias], #0x20]\n" "cmp x22, #0xc\n" - "add %x[bias], %x[bias], #0x30\n" + "add %x[bias], %x[bias], %x[bias_step]\n" "str q16, [x21, #0x0]\n" "str q26, [x21, #0x10]\n" "str q8, [x21, #0x20]\n" @@ -451,8 +462,9 @@ void kai_run_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon( "add %x[out], %x[out], #0x60\n" "bge 13b\n" "21:" // Done - : [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out) - : [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), [width] "r"(width) + : [bias] "+&r"(bias_ptr), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out) + : [bias_step] "r"(bias_step), [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), + [width] "r"(width) : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); diff --git a/test/common/MatMulMethod.hpp b/test/common/MatMulMethod.hpp index daf91a7e..127e4910 100644 --- a/test/common/MatMulMethod.hpp +++ b/test/common/MatMulMethod.hpp @@ -286,7 +286,6 @@ struct MatMulMethod { KAI_UNUSED(rhs_row_stride); KAI_UNUSED(bias); KAI_UNUSED(scale); - KAI_UNUSED(packed_rhs); if (fn_pack_rhs != nullptr) { fn_pack_rhs( diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp index 493af592..104a507a 100644 --- a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -90,7 +90,54 @@ static const std::array matmul_methods = { .fn_matmul_f32_bf16p_bf16p = 
kai_run_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, }, -}; + MatMulMethod{ + .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_opt_bias", + + .m0 = 8, + .n0 = 12, + + .lhs_transposed = false, + .rhs_transposed = false, + + .is_sme2 = false, + + .dst_format = DataFormat(DataType::FP32), + .lhs_format = DataFormat(DataType::FP32), + .packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4), + .rhs_format = DataFormat(DataType::FP32), + .packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4), + .bias_format = DataFormat(DataType::UNKNOWN), + + .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + + .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p8x4_bf16_neon, + .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p8x4_bf16_neon, + .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_pack_f32p8x4_bf16_neon, + .fn_pack_lhs = kai_run_lhs_pack_f32p8x4_bf16_neon, + + .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_get_main_packed_rhs_offset = + 
kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_pack_rhs = kai_run_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + + .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + + .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + }}; /// Matrix multiplication test fixture. class MatMulTestBf16 : public testing::TestWithParam { @@ -156,10 +203,11 @@ protected: } std::vector packed_rhs; + packed_rhs.resize(method.packed_rhs_format.default_size_in_bytes(rhs_h, rhs_w)); + if (has_rhs_pack) { - packed_rhs = matmul_pack_rhs( - rhs.data(), !rhs_scales.empty() ? rhs_scales.data() : nullptr, bias.data(), method.rhs_format, - method.packed_rhs_format, info.n, info.k, !method.rhs_transposed); + const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(rhs_w); + method.pack_rhs(info.n, info.k, rhs.data(), ref_rhs_row_stride, bias.data(), nullptr, packed_rhs.data()); } KAI_ASSUME(method.lhs_format.is_raw()); -- GitLab From 79f6780ca0b91601056de0a55a8303b15234a03d Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Thu, 10 Oct 2024 11:42:26 +0100 Subject: [PATCH 04/10] Addressed comments in the merge request Signed-off-by: Gunes Bayir --- .gitlab-ci.yml | 2 + .../CMakeLists.txt | 4 +- .../matmul_clamp_f32_bf16p_bf16p.cpp | 71 ++-- kai/kai_common.h | 2 +- kai/ukernels/matmul/BUILD.bazel | 7 +- ..._bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c | 9 +- ..._bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h | 10 +- ..._matmul_clamp_f32_bf16p_bf16p_interface.h} | 6 +- .../pack/kai_lhs_pack_f32p8x4_bf16_neon.h | 48 +++ test/common/MatMulMethod.hpp | 329 ------------------ test/common/matmul_test_common.cpp | 1 - test/common/matmul_test_common.hpp | 322 ++++++++++++++++- 
test/common/memory.hpp | 12 +- test/reference/pack.cpp | 15 +- test/reference/pack.hpp | 4 +- .../matmul_clamp_f32_bf16p_bf16p_test.cpp | 20 +- test/tests/matmul_test.cpp | 4 +- 17 files changed, 443 insertions(+), 423 deletions(-) rename kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/{matmul_clamp_f32_bf16p_bf16p_interface.h => kai_matmul_clamp_f32_bf16p_bf16p_interface.h} (90%) delete mode 100644 test/common/MatMulMethod.hpp diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index b9c625c7..e8b0b2f3 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -108,6 +108,7 @@ build-examples: matrix: - EXAMPLE: - matmul_clamp_f16_f16_f16p + - matmul_clamp_f32_bf16p_bf16p - matmul_clamp_f32_qai8dxp_qsi4cxp - matmul_clamp_f32_qsi8d32p_qsi4c32p - matmul_clamp_f32_qai8dxp_qsi4c32p @@ -130,6 +131,7 @@ test-examples: matrix: - EXAMPLE: - matmul_clamp_f16_f16_f16p + - matmul_clamp_f32_bf16p_bf16p - matmul_clamp_f32_qai8dxp_qsi4cxp - matmul_clamp_f32_qsi8d32p_qsi4c32p - matmul_clamp_f32_qai8dxp_qsi4c32p diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt index 62007cf7..ff2fd03e 100644 --- a/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt +++ b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt @@ -4,7 +4,9 @@ # SPDX-License-Identifier: Apache-2.0 # -cmake_minimum_required(VERSION 3.16) +cmake_minimum_required(VERSION 3.18) + +project(KleidiAI) set(CMAKE_CXX_STANDARD 17) set(KLEIDIAI_PATH ../../) diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp index 49f5a1f9..34b6db5e 100644 --- a/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp +++ b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp @@ -4,44 +4,38 @@ // SPDX-License-Identifier: Apache-2.0 // -// Example usage for matrix multiplication of two half precision floating-point (FP16) matrices and the accumulation of 
-// the result into an FP16 destination matrix. +// Example usage for matrix multiplication of two half-precision brain floating-point (BF16) matrices +// and the accumulation of the result into an FP32 destination matrix. // // The activations and the weights, stored in the LHS and RHS matrices respectively, are both non-transposed matrices. -// The matrix multiplication computation is performed using floating-point fused multiply-add to accumulator (FMLA) -// vector instructions present in the FEAT_FP16 Arm® architecture feature. +// The matrix multiplication computation is performed using BF16 matrix multiply (BFMMLA) +// vector instructions present in the FEAT_BF16 Arm® architecture feature. // -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) || \ - !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) -#error This file must be compiled for AArch64, FEAT_FP16. +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_SCALAR_ARITHMETIC) || \ + !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_BF16. 
#else #include #include -#include #include #include #include #include +#include #include #include // Include micro-kernel variants +#include "kai/kai_common.h" #include "kai_lhs_pack_f32p8x4_bf16_neon.h" #include "kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h" +#include "kai_matmul_clamp_f32_bf16p_bf16p_interface.h" #include "kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h" -#include "matmul_clamp_f32_bf16p_bf16p_interface.h" - -inline float bf16_to_float(uint16_t v) { - const uint32_t lv = (v << 16); - float fp; - memcpy(&fp, &lv, sizeof(lv)); - return fp; -} inline float bf16_to_float(const bfloat16_t* v) { const uint16_t uint_rep = *reinterpret_cast(v); - return bf16_to_float(uint_rep); + return kai_cast_f32_bf16(uint_rep); } namespace { @@ -53,12 +47,15 @@ constexpr kai_matmul_clamp_f32_bf16p_bf16p_ukernel ukernel{ kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - kai_get_lhs_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, kai_run_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla}; +/// @brief Truncate the 32-bit floating point number's least significant 16 mantissa bits +/// @param x floating-point number +/// @return truncated floating-point number float truncate(float x) { uint32_t uval = (*reinterpret_cast(&x) & 0xffff0000); return *reinterpret_cast(&uval); @@ -78,10 +75,8 @@ void run_matmul_ref( acc += lhs_val * rhs_val; } - acc = std::max(acc, scalar_min); - acc = std::min(acc, scalar_max); - dst[row_idx * n + col_idx] = acc; + dst[row_idx * 
n + col_idx] = std::clamp(acc, scalar_min, scalar_max); } } } @@ -220,12 +215,12 @@ int main() { float* dst_ref = new float[dst_size]; run_matmul_ref( - M, N, K, // Dimensions - lhs, // LHS buffer - rhs, // RHS buffer - bias, // Bias buffer - dst_ref, // DST - FLT_MIN, FLT_MAX // Min and max for the clamp operation + M, N, K, // Dimensions + lhs, // LHS buffer + rhs, // RHS buffer + bias, // Bias buffer + dst_ref, // DST + -FLT_MAX, FLT_MAX // Min and max for the clamp operation ); //----------- END REFERENCE IMPLEMENTATION //------------------------------------ @@ -275,18 +270,18 @@ int main() { float* dst = new float[dst_size]; - kai_run_lhs_pack_f32p8x4_bf16_neon(M, K, mr, kr, sr, 0 /* m_idx_start */, lhs, lhs_stride, lhs_packed); - const auto timer_matmul_start = std::chrono::high_resolution_clock::now(); + kai_run_lhs_pack_f32p8x4_bf16_neon(M, K, mr, kr, sr, 0 /* m_idx_start */, lhs, lhs_stride, lhs_packed); + ukernel.run_matmul( - M, N, K, // Dimensions - lhs_packed, // LHS packed - rhs_packed, // RHS packed - dst, // DST - dst_stride_row, // DST stride (row) - dst_stride_col, // DST stride (col) - FLT_MIN, FLT_MAX // Min and max for the clamp operation + M, N, K, // Dimensions + lhs_packed, // LHS packed + rhs_packed, // RHS packed + dst, // DST + dst_stride_row, // DST stride (row) + dst_stride_col, // DST stride (col) + -FLT_MAX, FLT_MAX // Min and max for the clamp operation ); const auto timer_matmul_end = std::chrono::high_resolution_clock::now(); @@ -302,7 +297,8 @@ int main() { print_matrix(M, N, "ref", dst_ref); #endif // KAI_DEBUG - const bool is_valid = is_output_correct(M, N, 0.02 /* rel tol */, dst_ref, dst); + constexpr float rel_tolerance = 0.02; // This value was chosen by experimentation + const bool is_valid = is_output_correct(M, N, rel_tolerance, dst_ref, dst); std::cout << "TEST[matmul_clamp_f32_bf16p_bf16p]\n"; std::cout << "- ukernel: matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla\n"; @@ -321,6 +317,7 @@ int main() { delete[] 
lhs; delete[] rhs; delete[] bias; + delete[] lhs_packed; delete[] rhs_packed; delete[] dst; delete[] dst_ref; diff --git a/kai/kai_common.h b/kai/kai_common.h index 8fe70424..27034185 100644 --- a/kai/kai_common.h +++ b/kai/kai_common.h @@ -23,7 +23,7 @@ extern "C" { #define KAI_ERROR(msg) \ do { \ fflush(stdout); \ - fprintf(stderr, "%s:%d %s\n", __FILE__, __LINE__, msg); \ + fprintf(stderr, "%s:%d %s", __FILE__, __LINE__, msg); \ exit(EXIT_FAILURE); \ } while (0) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index b2f6ab58..8fc817ab 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -41,8 +41,8 @@ kai_c_library( kai_c_library( name = "clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla", - srcs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c"], - hdrs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h"], + srcs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c"], + hdrs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h"], cpu_uarch = kai_cpu_bf16(), deps = [ ":clamp_f32_bf16p_bf16p_interface", @@ -337,6 +337,7 @@ kai_c_library( name = "matmul", deps = [ ":clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla", + ":clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla", ":clamp_f32_f32_f32p", ":clamp_f32_f32_f32pb_1x16vl_sme2_mla", ":clamp_f32_f32p_f32p", @@ -358,11 +359,13 @@ kai_c_library( ":clamp_f32_qsi8d32p_qsi4c32p_i8mm", ":clamp_f32_qsi8d32p_qsi4c32p_interface", ":lhs_pack_f32p2vlx1_f32_sme", + ":lhs_pack_f32p8x4_bf16_neon", ":lhs_quant_pack_qai8dxp_f32", ":lhs_quant_pack_qsi8d32p_f32", ":rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon", ":rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", ":rhs_pack_kxn_f32pb_f32_f32_16vlx1_sme", + ":rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon", ":rhs_pack_kxn_f32pbiasf32_f32_f32_neon", 
":rhs_pack_kxn_qsi4c32p_qsu4c32s1s0", ":rhs_pack_kxn_qsi4cxp_qs4cxs1s0", diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c index 13ebc4e3..7132d15b 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c @@ -10,7 +10,6 @@ #include #include -#include #include #include @@ -48,16 +47,15 @@ size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) return kai_sr; } -size_t kai_get_lhs_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t m_idx, size_t stride) { +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t m_idx, size_t k) { KAI_ASSUME(m_idx % kai_mr == 0); - return m_idx * stride; + return m_idx * kai_roundup(k, kai_kr) * sizeof(bfloat16_t); } size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t n_idx, size_t k) { KAI_ASSUME(n_idx % kai_nr == 0); - // return n_idx / kai_nr * (kai_nr * sizeof(float) + kai_nr * k * sizeof(bfloat16)); return n_idx * (sizeof(float) + kai_roundup(k, kai_kr) * sizeof(bfloat16_t)); } @@ -82,7 +80,6 @@ void kai_run_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla( KAI_ASSERT(dst_stride_col == sizeof(float)); const void* Apanel = lhs_packed; - // const void *Bpanel = rhs_packed; void* Cpanel = dst; size_t ldc = dst_stride_row / sizeof(float); @@ -100,7 +97,7 @@ void kai_run_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla( KernelArgs ka; ka.N = n; - ka.K = kai_roundup(k, 4) / 4 - 1; + ka.K = kai_roundup(k, kai_kr) / kai_kr - 1; ka.Bpanel = rhs_packed; diff --git 
a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h index 000d690a..316933ed 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h @@ -60,13 +60,13 @@ size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) /// @return The sr value. size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void); -/// Gets the offset in bytes to the data element in the LHS matrix buffer. +/// Gets the offset in bytes to the data element in the packed LHS matrix buffer. /// /// @param[in] m_idx Row index. -/// @param[in] stride Row stride in bytes. +/// @param[in] k Number of columns. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t m_idx, size_t stride); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t m_idx, size_t k); /// Gets the offset in bytes to the data element in the packed RHS matrix buffer. /// @@ -96,7 +96,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla /// Runs the matrix multiplication microkernel followed by a clamp operation. /// -/// The pointer of each buffers (LHS, packed RHS and output) needs to be added with offset +/// The pointer of each buffers (packed LHS, packed RHS and output) needs to be added with offset /// calculated using the following functions: /// /// * Packed LHS: @ref kai_get_lhs_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla. 
@@ -106,7 +106,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. /// @param[in] k Common dimension of the LHS and RHS operand. -/// @param[in] lhs LHS matrix buffer. +/// @param[in] lhs_packed Packed LHS buffer. /// @param[in] rhs_packed Packed RHS buffer. /// @param[out] dst Output matrix buffer. /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p_interface.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h similarity index 90% rename from kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p_interface.h rename to kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h index 8eb1969a..768f63f5 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p_interface.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h @@ -6,7 +6,7 @@ #pragma once #if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) -#error This file must be compiled for AArch64, FEAT_BF16. +#error This file must be compiled for AArch64, FEAT_BF16 #else // Architectural features check. 
#include @@ -25,7 +25,7 @@ typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_mr_func_t)(void); typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_nr_func_t)(void); typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_kr_func_t)(void); typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_sr_func_t)(void); -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_offset_func_t)(size_t m_idx, size_t lhs_stride); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k); typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride); typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_size_func_t)(size_t m, size_t n); @@ -43,7 +43,7 @@ struct kai_matmul_clamp_f32_bf16p_bf16p_ukernel { kai_matmul_clamp_f32_bf16p_bf16p_get_nr_func_t get_nr; kai_matmul_clamp_f32_bf16p_bf16p_get_nr_func_t get_kr; kai_matmul_clamp_f32_bf16p_bf16p_get_sr_func_t get_sr; - kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_offset_func_t get_lhs_packed_offset; + kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_packed_offset_func_t get_lhs_packed_offset; kai_matmul_clamp_f32_bf16p_bf16p_get_rhs_packed_offset_func_t get_rhs_packed_offset; kai_matmul_clamp_f32_bf16p_bf16p_get_dst_offset_func_t get_dst_offset; kai_matmul_clamp_f32_bf16p_bf16p_get_dst_size_func_t get_dst_size; diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.h b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.h index dd47ba88..d43838bc 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.h +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.h @@ -14,14 +14,62 @@ extern "C" { #include "kai/kai_common.h" +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @param[in] mr Number of rows to be interleaved. 
+/// +/// @return The m step value. size_t kai_get_m_step_lhs_pack_f32p8x4_bf16_neon(size_t mr); +/// Gets the offset in bytes to the data element in the LHS buffer. +/// +/// @param[in] m_idx Row index. +/// @param[in] lhs_stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. size_t kai_get_lhs_offset_lhs_pack_f32p8x4_bf16_neon(size_t m_idx, size_t lhs_stride); +/// Gets the offset in bytes to the data element in the packed LHS buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// +/// @note The packing parameters are fixed for this kernel (mr = 8, kr = 4, sr = 1), +/// so they are not taken as arguments. +/// +/// @return The offset in bytes to the data element. size_t kai_get_lhs_packed_offset_lhs_pack_f32p8x4_bf16_neon(size_t m_idx, size_t k); +/// Gets the size in bytes of the packed LHS buffer. +/// +/// @param[in] m Number of rows in the unpacked LHS matrix. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// @param[in] mr Number of rows to be interleaved. +/// @param[in] kr Number of columns to be interleaved. +/// @param[in] sr Unused. Must be 1. +/// +/// @return The size in bytes of the packed LHS buffer. size_t kai_get_lhs_packed_size_lhs_pack_f32p8x4_bf16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr); +/// Runs the LHS packing function for matrix multiplication. +/// +/// The pointer of each buffers (LHS and packed LHS) needs to be added with offset +/// calculated using the following functions: +/// +/// * LHS: @ref kai_get_lhs_offset_lhs_pack_f32p8x4_bf16_neon. +/// * Packed LHS: @ref kai_get_lhs_packed_offset_lhs_pack_f32p8x4_bf16_neon. +/// +/// @param[in] m Number of rows of the unpacked LHS matrix. +/// @param[in] k Common dimension between the LHS and RHS matrix. +/// @param[in] mr Block size in M dimension. It must be 8. +/// @param[in] kr Block size in K dimension. It must be 4. 
+/// @param[in] sr Number of kr splits. It must be 1. +/// @param[in] m_idx_start Unused. Must be 0. +/// @param[in] lhs LHS matrix data buffer. +/// @param[in] lhs_stride Row stride in bytes of the LHS matrix. +/// @param[out] lhs_packed Packed LHS matrix. void kai_run_lhs_pack_f32p8x4_bf16_neon( size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed); diff --git a/test/common/MatMulMethod.hpp b/test/common/MatMulMethod.hpp deleted file mode 100644 index 127e4910..00000000 --- a/test/common/MatMulMethod.hpp +++ /dev/null @@ -1,329 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// -#pragma once - -#include -#include -#include -#include - -#include "kai/kai_common.h" -#include "test/common/data_format.hpp" -#include "test/common/float16.hpp" - -namespace kai::test { - -// NOLINTBEGIN(misc-non-private-member-variables-in-classes) - -/// Matrix multiplication method. -struct MatMulMethod { - std::string_view name; ///< Name of matmul method. - - size_t m0; ///< Block size in M dimension. - size_t n0; ///< Block size in N dimension. - - bool lhs_transposed; ///< LHS matrix is transposed. - bool rhs_transposed; ///< RHS matrix is transposed. - - bool is_sme2; ///< Test is a sme2 test - - DataFormat dst_format; ///< Data format of the destination matrix. - DataFormat lhs_format; ///< Data format of the LHS matrix. - DataFormat packed_lhs_format; ///< Data format of the packed LHS matrix. - DataFormat rhs_format; ///< Data format of the RHS matrix. - DataFormat packed_rhs_format; ///< Data format of the packed RHS matrix. - DataFormat bias_format; ///< Data format of the bias vector. - - /// Gets mr value. - /// - /// This is the packing parameter which must be used to pack the LHS matrix (if necessary). - /// - /// @return The mr value. - std::function fn_get_mr; - - /// Gets nr value. 
- /// - /// This is the packing parameter which must be used to pack the RHS matrix (if necessary). - /// - /// @return The nr value. - std::function fn_get_nr; - - /// Gets kr value. - /// - /// This is the packing parameter which must be used to pack the LHS and RHS matrix (if necessary). - /// - /// @return The kr value. - std::function fn_get_kr; - - /// Gets sr value. - /// - /// This is the packing parameter which must be used to pack the RHS matrix. - /// - /// @return The sr value. - std::function fn_get_sr; - - /// Gets m step value for main kernel. - /// - /// The starting row index must be divisible by `m_step`. - /// - /// @return The m step value. - std::function fn_get_main_m_step; - - /// Gets n step value for RHS packing kernel. - /// - /// The starting row index must be divisible by `n_step`. - /// - /// @return The n step value. - std::function fn_get_pack_rhs_n_step; - - /// Gets n step value for main kernel. - /// - /// The starting column index must be divisible by `n_step`. - /// - /// @return The n step value. - std::function fn_get_main_n_step; - - /// Gets the offset in bytes of the LHS matrix. - /// - /// @param[in] m_idx Coordinate of the matrix in M dimension. - /// @param[in] stride Row stride in bytes. - /// - /// @return The offset in bytes. - std::function fn_get_lhs_offset; - - /// Gets the size in bytes of the packed LHS matrix. - /// - /// @param[in] m Number of rows in the unpacked LHS matrix. - /// @param[in] k Number of columns in the unpacked LHS matrix. - /// @param[in] mr Number of rows to be interleaved. - /// @param[in] kr Unused. Must be 1. - /// @param[in] sr Unused. Must be 1. - /// - /// @return The size in bytes. - std::function fn_get_packed_lhs_size; - - /// Gets the offset in bytes of the packed LHS matrix. - /// - /// @param[in] m_idx Coordinate of the matrix in M dimension. - /// @param[in] k Size of the matrix in K dimension. - /// - /// @return The offset in bytes. 
- std::function fn_get_packed_lhs_offset; - - /// Preprocesses the LHS matrix. - /// - /// @param[in] m Number of rows of the unpacked LHS matrix. - /// @param[in] k Common dimension between the LHS and RHS matrix. - /// @param[in] mr Block size in M dimension. It must be {{ kernel.interleave_by }}VL. - /// @param[in] kr Block size in K dimension. It must be {{ kernel.block_by }}. - /// @param[in] sr Number of kr splits. It must be 1. - /// @param[in] m_idx_start Unused. Must be 0. - /// @param[in] lhs LHS matrix data buffer. - /// @param[in] lhs_stride Row stride in bytes of the LHS matrix. - /// @param[out] lhs_packed Packed RHS matrix. - std::function - fn_pack_lhs; - - /// Gets a value indicating whether LHS packing is needed. - [[nodiscard]] bool is_pack_lhs_needed() const { - return fn_pack_lhs != nullptr; - } - - /// Gets the offset in bytes of the RHS matrix. - /// - /// @param[in] n_idx Coordinate of the matrix in N dimension. - /// - /// @return The offset in bytes. - std::function fn_get_rhs_offset; - - /// Gets the size in bytes of the packed RHS matrix. - /// - /// @param[in] n Size of the matrix in N dimension. - /// @param[in] k Size of the matrix in K dimension. - /// - /// @return The size in bytes. - std::function fn_get_packed_rhs_size; - - /// Gets the offset in bytes of the packed RHS matrix in the RHS packing kernel - /// - /// @param[in] n_idx Coordinate of the matrix in N dimension. - /// @param[in] k Size of the matrix in K dimension. - /// - /// @return The offset in bytes. - std::function fn_get_pack_rhs_packed_rhs_offset; - - /// Gets the offset in bytes of the packed RHS matrix in the main kernel. - /// - /// @param[in] n_idx Coordinate of the matrix in N dimension. - /// @param[in] k Size of the matrix in K dimension. - /// - /// @return The offset in bytes. - std::function fn_get_main_packed_rhs_offset; - - std::function - fn_pack_rhs; - - /// Gets the offset in bytes to the data element in the bias buffer. 
- /// - /// @param[in] n_idx Column index. - /// - /// @return The offset in bytes to the data element. - std::function fn_get_bias_offset; - - /// Gets the offset in bytes to the data element in the destination matrix buffer. - /// - /// @param[in] m_idx Row index. - /// @param[in] n_idx Column index. - /// @param[in] stride Row stride in bytes. - /// - /// @return The offset in bytes to the data element. - std::function fn_get_dst_offset; - - /// Gets the size in bytes of the destination matrix buffer. - /// - /// @param[in] m Number of rows. - /// @param[in] n Number of columns. - /// - /// @return The size in bytes of the destination matrix buffer. - std::function fn_get_dst_size; - - /// Performs F16 or F32 matrix multiplication with RHS packing - /// followed by clamp operation. - /// - /// @param[in] m Size of the matrix in M dimension. - /// @param[in] n Size of the matrix in N dimension. - /// @param[in] k Size of the matrix in K dimension. - /// @param[in] lhs LHS data buffer. - /// @param[in] packed_rhs Packed RHS data buffer. - /// @param[out] dst Output data buffer. - /// @param[in] lhs_stride LHS row stride. - /// @param[in] dst_stride_row Output row stride. - /// @param[in] dst_stride_col Output column stride. - /// @param[in] clamp_min Lower bound of the output data. - /// @param[in] clamp_max Upper bound of the output data. - std::function - fn_matmul_f16_f16_f16p; - - std::function - fn_matmul_f32_f32_f32p; - - /// Performs BF16 matrix multiplication with LHS and RHS packing - /// followed by clamp operation. - /// - /// @param[in] m Size of the matrix in M dimension. - /// @param[in] n Size of the matrix in N dimension. - /// @param[in] k Size of the matrix in K dimension. - /// @param[in] packed_lhs Packed LHS data buffer. - /// @param[in] packed_rhs Packed RHS data buffer. - /// @param[out] dst Output data buffer. - /// @param[in] dst_stride_row Output row stride. - /// @param[in] dst_stride_col Output column stride. 
- /// @param[in] clamp_min Lower bound of the output data. - /// @param[in] clamp_max Upper bound of the output data. - std::function - fn_matmul_f32_bf16p_bf16p; - - /// Performs F32 matrix multiplication with LHS & RHS packing - /// followed by clamp operation. - /// - /// @param[in] m Number of output rows to be computed. - /// @param[in] n Number of output columns to be computed. - /// @param[in] k Common dimension of the LHS and RHS operands. - /// @param[in] packed_lhs Packed LHS matrix buffer. - /// @param[in] packed_rhs Packed RHS matrix buffer. - /// @param[out] dst Output matrix buffer. - /// @param[in] dst_stride_row Row stride in bytes of the output matrix. - /// @param[in] dst_stride_col Column stride in bytes of the output matrix. - /// @param[in] clamp_min Minimum value to clamp the final result. - /// @param[in] clamp_max Maximum value to clamp the final result. - std::function - fn_matmul_f32_f32p_f32p; - - /// Gets a value indicating whether pre-processing the RHS matrix is needed. - [[nodiscard]] bool is_pack_rhs_needed() const { - return fn_pack_rhs != nullptr; - } - - /// Preprocesses the RHS matrix. - /// - /// @param[in] n Size of the matrix in N dimension. - /// @param[in] k Size of the matrix in K dimension. - /// @param[in] rhs RHS data buffer. - /// @param[in] rhs_row_stride RHS row stride. - /// @param[in] bias Bias data buffer. - /// @param[in] scale Quantization scales data buffer. - /// @param[out] packed_rhs Packed RHS data buffer. 
- void pack_rhs( - size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale, - void* packed_rhs) const { - KAI_UNUSED(n); - KAI_UNUSED(k); - KAI_UNUSED(rhs); - KAI_UNUSED(rhs_row_stride); - KAI_UNUSED(bias); - KAI_UNUSED(scale); - - if (fn_pack_rhs != nullptr) { - fn_pack_rhs( - 1, n, k, fn_get_nr(), fn_get_kr(), fn_get_sr(), rhs_row_stride, rhs, bias, nullptr, packed_rhs, 0, - nullptr); - } else { - KAI_ERROR("RHS pre-processing is not supported!"); - } - } - - [[nodiscard]] bool has_main_kernel() const { - return fn_matmul_f16_f16_f16p != nullptr || fn_matmul_f32_f32p_f32p != nullptr || - fn_matmul_f32_f32_f32p != nullptr || fn_matmul_f32_bf16p_bf16p != nullptr; - } - - void main_kernel( - size_t m, size_t n, size_t k, const void* lhs, const void* rhs, const void* bias, void* dst, size_t lhs_stride, - size_t rhs_stride, size_t dst_stride, float clamp_min, float clamp_max) const { - KAI_UNUSED(bias); - KAI_UNUSED(rhs_stride); - - if (fn_matmul_f16_f16_f16p) { - fn_matmul_f16_f16_f16p( - m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(Float16), clamp_min, - static_cast(clamp_max)); - } else if (fn_matmul_f32_f32_f32p) { - fn_matmul_f32_f32_f32p( - m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(float), clamp_min, - static_cast(clamp_max)); - } else if (fn_matmul_f32_f32p_f32p) { - fn_matmul_f32_f32p_f32p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); - } else if (fn_matmul_f32_bf16p_bf16p) { - fn_matmul_f32_bf16p_bf16p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); - } else { - KAI_ERROR("Main kernel is not available!"); - } - } -}; - -// NOLINTEND(misc-non-private-member-variables-in-classes) -} // namespace kai::test diff --git a/test/common/matmul_test_common.cpp b/test/common/matmul_test_common.cpp index 905450fb..73d41c09 100644 --- a/test/common/matmul_test_common.cpp +++ b/test/common/matmul_test_common.cpp @@ -7,7 +7,6 @@ #include 
"matmul_test_common.hpp" #include -#include namespace kai::test { void PrintTo(const MatMulTestParams& param, std::ostream* os) { diff --git a/test/common/matmul_test_common.hpp b/test/common/matmul_test_common.hpp index 767117c5..90eb95d9 100644 --- a/test/common/matmul_test_common.hpp +++ b/test/common/matmul_test_common.hpp @@ -3,11 +3,17 @@ // // SPDX-License-Identifier: Apache-2.0 // +#pragma once #include +#include +#include +#include #include -#include "test/common/MatMulMethod.hpp" +#include "kai/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/float16.hpp" #include "test/common/matrix_portion.hpp" namespace kai::test { @@ -18,6 +24,320 @@ struct MatMulShape { size_t k; ///< LHS width and RHS height. }; +// NOLINTBEGIN(misc-non-private-member-variables-in-classes) + +/// Matrix multiplication method. +struct MatMulMethod { + std::string_view name; ///< Name of matmul method. + + size_t m0; ///< Block size in M dimension. + size_t n0; ///< Block size in N dimension. + + bool lhs_transposed; ///< LHS matrix is transposed. + bool rhs_transposed; ///< RHS matrix is transposed. + + DataFormat dst_format; ///< Data format of the destination matrix. + DataFormat lhs_format; ///< Data format of the LHS matrix. + DataFormat packed_lhs_format; ///< Data format of the packed LHS matrix. + DataFormat rhs_format; ///< Data format of the RHS matrix. + DataFormat packed_rhs_format; ///< Data format of the packed RHS matrix. + DataFormat bias_format; ///< Data format of the bias vector. + + /// Check if CPU supports required features. + /// + /// @return Supported (true) or not supported (false). + std::function fn_is_supported; + + /// Gets mr value. + /// + /// This is the packing parameter which must be used to pack the LHS matrix (if necessary). + /// + /// @return The mr value. + std::function fn_get_mr; + + /// Gets nr value. + /// + /// This is the packing parameter which must be used to pack the RHS matrix (if necessary). 
+ /// + /// @return The nr value. + std::function fn_get_nr; + + /// Gets kr value. + /// + /// This is the packing parameter which must be used to pack the LHS and RHS matrix (if necessary). + /// + /// @return The kr value. + std::function fn_get_kr; + + /// Gets sr value. + /// + /// This is the packing parameter which must be used to pack the RHS matrix. + /// + /// @return The sr value. + std::function fn_get_sr; + + /// Gets m step value for main kernel. + /// + /// The starting row index must be divisible by `m_step`. + /// + /// @return The m step value. + std::function fn_get_main_m_step; + + /// Gets n step value for RHS packing kernel. + /// + /// The starting row index must be divisible by `n_step`. + /// + /// @return The n step value. + std::function fn_get_pack_rhs_n_step; + + /// Gets n step value for main kernel. + /// + /// The starting column index must be divisible by `n_step`. + /// + /// @return The n step value. + std::function fn_get_main_n_step; + + /// Gets the offset in bytes of the LHS matrix. + /// + /// @param[in] m_idx Coordinate of the matrix in M dimension. + /// @param[in] stride Row stride in bytes. + /// + /// @return The offset in bytes. + std::function fn_get_lhs_offset; + + /// Gets the size in bytes of the packed LHS matrix. + /// + /// @param[in] m Number of rows in the unpacked LHS matrix. + /// @param[in] k Number of columns in the unpacked LHS matrix. + /// @param[in] mr Number of rows to be interleaved. + /// @param[in] kr Unused. Must be 1. + /// @param[in] sr Unused. Must be 1. + /// + /// @return The size in bytes. + std::function fn_get_packed_lhs_size; + + /// Gets the offset in bytes of the packed LHS matrix. + /// + /// @param[in] m_idx Coordinate of the matrix in M dimension. + /// @param[in] k Size of the matrix in K dimension. + /// + /// @return The offset in bytes. + std::function fn_get_packed_lhs_offset; + + /// Preprocesses the LHS matrix. + /// + /// @param[in] m Number of rows of the unpacked LHS matrix. 
+ /// @param[in] k Common dimension between the LHS and RHS matrix. + /// @param[in] mr Block size in M dimension. It must be the value returned by fn_get_mr. + /// @param[in] kr Block size in K dimension. It must be the value returned by fn_get_kr. + /// @param[in] sr Number of kr splits. It must be 1. + /// @param[in] m_idx_start Unused. Must be 0. + /// @param[in] lhs LHS matrix data buffer. + /// @param[in] lhs_stride Row stride in bytes of the LHS matrix. + /// @param[out] lhs_packed Packed LHS matrix. + std::function + fn_pack_lhs; + + /// Gets a value indicating whether LHS packing is needed. + [[nodiscard]] bool is_pack_lhs_needed() const { + return fn_pack_lhs != nullptr; + } + + /// Gets the offset in bytes of the RHS matrix. + /// + /// @param[in] n_idx Coordinate of the matrix in N dimension. + /// + /// @return The offset in bytes. + std::function fn_get_rhs_offset; + + /// Gets the size in bytes of the packed RHS matrix. + /// + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// + /// @return The size in bytes. + std::function fn_get_packed_rhs_size; + + /// Gets the offset in bytes of the packed RHS matrix in the RHS packing kernel + /// + /// @param[in] n_idx Coordinate of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// + /// @return The offset in bytes. + std::function fn_get_pack_rhs_packed_rhs_offset; + + /// Gets the offset in bytes of the packed RHS matrix in the main kernel. + /// + /// @param[in] n_idx Coordinate of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// + /// @return The offset in bytes. + std::function fn_get_main_packed_rhs_offset; + + std::function + fn_pack_rhs; + + /// Gets the offset in bytes to the data element in the bias buffer. + /// + /// @param[in] n_idx Column index. + /// + /// @return The offset in bytes to the data element. 
+ std::function fn_get_bias_offset; + + /// Gets the offset in bytes to the data element in the destination matrix buffer. + /// + /// @param[in] m_idx Row index. + /// @param[in] n_idx Column index. + /// @param[in] stride Row stride in bytes. + /// + /// @return The offset in bytes to the data element. + std::function fn_get_dst_offset; + + /// Gets the size in bytes of the destination matrix buffer. + /// + /// @param[in] m Number of rows. + /// @param[in] n Number of columns. + /// + /// @return The size in bytes of the destination matrix buffer. + std::function fn_get_dst_size; + + /// Performs F16 or F32 matrix multiplication with RHS packing + /// followed by clamp operation. + /// + /// @param[in] m Size of the matrix in M dimension. + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] lhs LHS data buffer. + /// @param[in] packed_rhs Packed RHS data buffer. + /// @param[out] dst Output data buffer. + /// @param[in] lhs_stride LHS row stride. + /// @param[in] dst_stride_row Output row stride. + /// @param[in] dst_stride_col Output column stride. + /// @param[in] clamp_min Lower bound of the output data. + /// @param[in] clamp_max Upper bound of the output data. + std::function + fn_matmul_f16_f16_f16p = nullptr; + + std::function + fn_matmul_f32_f32_f32p = nullptr; + + /// Performs BF16 matrix multiplication with LHS and RHS packing + /// followed by clamp operation. + /// + /// @param[in] m Size of the matrix in M dimension. + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] packed_lhs Packed LHS data buffer. + /// @param[in] packed_rhs Packed RHS data buffer. + /// @param[out] dst Output data buffer. + /// @param[in] dst_stride_row Output row stride. + /// @param[in] dst_stride_col Output column stride. + /// @param[in] clamp_min Lower bound of the output data. 
+ /// @param[in] clamp_max Upper bound of the output data. + std::function + fn_matmul_f32_bf16p_bf16p = nullptr; + + /// Performs F32 matrix multiplication with LHS & RHS packing + /// followed by clamp operation. + /// + /// @param[in] m Number of output rows to be computed. + /// @param[in] n Number of output columns to be computed. + /// @param[in] k Common dimension of the LHS and RHS operands. + /// @param[in] packed_lhs Packed LHS matrix buffer. + /// @param[in] packed_rhs Packed RHS matrix buffer. + /// @param[out] dst Output matrix buffer. + /// @param[in] dst_stride_row Row stride in bytes of the output matrix. + /// @param[in] dst_stride_col Column stride in bytes of the output matrix. + /// @param[in] clamp_min Minimum value to clamp the final result. + /// @param[in] clamp_max Maximum value to clamp the final result. + std::function + fn_matmul_f32_f32p_f32p = nullptr; + + /// Gets a value indicating whether pre-processing the RHS matrix is needed. + [[nodiscard]] bool is_pack_rhs_needed() const { + return fn_pack_rhs != nullptr; + } + + /// Preprocesses the RHS matrix. + /// + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] rhs RHS data buffer. + /// @param[in] rhs_row_stride RHS row stride. + /// @param[in] bias Bias data buffer. + /// @param[in] scale Quantization scales data buffer. + /// @param[out] packed_rhs Packed RHS data buffer. 
+ void pack_rhs( + size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale, + void* packed_rhs) const { + KAI_UNUSED(n); + KAI_UNUSED(k); + KAI_UNUSED(rhs); + KAI_UNUSED(rhs_row_stride); + KAI_UNUSED(bias); + KAI_UNUSED(scale); + + if (fn_pack_rhs != nullptr) { + fn_pack_rhs( + 1, n, k, fn_get_nr(), fn_get_kr(), fn_get_sr(), rhs_row_stride, rhs, bias, nullptr, packed_rhs, 0, + nullptr); + } else { + KAI_ERROR("RHS pre-processing is not supported!"); + } + } + + [[nodiscard]] bool has_main_kernel() const { + return fn_matmul_f16_f16_f16p != nullptr || fn_matmul_f32_f32p_f32p != nullptr || + fn_matmul_f32_f32_f32p != nullptr || fn_matmul_f32_bf16p_bf16p != nullptr; + } + + void main_kernel( + size_t m, size_t n, size_t k, const void* lhs, const void* rhs, const void* bias, void* dst, size_t lhs_stride, + size_t rhs_stride, size_t dst_stride, float clamp_min, float clamp_max) const { + KAI_UNUSED(bias); + KAI_UNUSED(rhs_stride); + + if (fn_matmul_f16_f16_f16p) { + fn_matmul_f16_f16_f16p( + m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(Float16), clamp_min, + static_cast(clamp_max)); + } else if (fn_matmul_f32_f32_f32p) { + fn_matmul_f32_f32_f32p( + m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(float), clamp_min, + static_cast(clamp_max)); + } else if (fn_matmul_f32_f32p_f32p) { + fn_matmul_f32_f32p_f32p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); + } else if (fn_matmul_f32_bf16p_bf16p) { + fn_matmul_f32_bf16p_bf16p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); + } else { + KAI_ERROR("Main kernel is not available!"); + } + } +}; + +// NOLINTEND(misc-non-private-member-variables-in-classes) + /// Matrix multiplication test information. 
using MatMulTestParams = std::tuple; diff --git a/test/common/memory.hpp b/test/common/memory.hpp index 7ea0aa53..c856218f 100644 --- a/test/common/memory.hpp +++ b/test/common/memory.hpp @@ -7,8 +7,10 @@ #pragma once #include +#include #include +#include "kai/kai_common.h" #include "test/common/bfloat16.hpp" #include "test/common/int4.hpp" @@ -26,14 +28,6 @@ inline constexpr size_t size_in_bits = 4; template <> inline constexpr size_t size_in_bits = 4; -/// TODO: Move this -inline float bf16_to_float(uint16_t v) { - const uint32_t lv = (v << 16); - float fp; - memcpy(&fp, &lv, sizeof(lv)); - return fp; -} - /// Reads the array at the specified index. /// /// @param[in] array Data buffer. @@ -50,7 +44,7 @@ T read_array(const void* array, size_t index) { return index % 2 == 0 ? lo : hi; } else if constexpr (std::is_same_v) { uint16_t raw_value = reinterpret_cast(array)[index]; - return BFloat16(bf16_to_float(raw_value)); + return BFloat16(kai_cast_f32_bf16(raw_value)); } else { return reinterpret_cast(array)[index]; } diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp index 5c549664..ad123762 100644 --- a/test/reference/pack.cpp +++ b/test/reference/pack.cpp @@ -10,17 +10,14 @@ #include #include #include -#include #include #include "kai/kai_common.h" #include "test/common/bfloat16.hpp" #include "test/common/data_format.hpp" #include "test/common/data_type.hpp" -#include "test/common/float16.hpp" #include "test/common/memory.hpp" #include "test/common/round.hpp" -#include "test/reference/quantize.hpp" namespace kai::test { @@ -37,9 +34,7 @@ std::vector pack_block( const auto dst_bytes = round_up_multiple(full_height, block_height) * round_up_multiple(full_width, block_width) * dst_esize; - std::vector dst; - dst.resize(dst_bytes); - memset(dst.data(), 0, dst_bytes); + std::vector dst(dst_bytes, 0); const auto* src_ptr = reinterpret_cast(src); auto* dst_ptr = dst.data(); @@ -106,9 +101,7 @@ std::vector pack_bias_per_row( const auto group_bytes = 
group_bias_bytes + group_num_blocks * block_data_bytes; const auto dst_bytes = num_groups * group_bytes; - std::vector dst; - dst.resize(dst_bytes); - memset(dst.data(), 0, dst_bytes); + std::vector dst(dst_bytes, 0); const auto* src_ptr = reinterpret_cast(src); const auto* bias_ptr = reinterpret_cast(bias); @@ -147,8 +140,8 @@ std::vector pack_bias_per_row( x_element) * src_esize; - uint16_t src_value = convert(src_ptr_elm, src_dtype, dst_dtype); - memcpy(dst_ptr, &src_value, dst_esize); + const uint16_t dst_value = convert(src_ptr_elm, src_dtype, dst_dtype); + memcpy(dst_ptr, &dst_value, dst_esize); } } diff --git a/test/reference/pack.hpp b/test/reference/pack.hpp index 10d76a7f..128ad040 100644 --- a/test/reference/pack.hpp +++ b/test/reference/pack.hpp @@ -22,8 +22,8 @@ class DataFormat; /// @param[in] height Number of rows of the source matrix. /// @param[in] width Number of columns of the source matrix. std::vector pack( - const DataFormat& dst_format, const void* src, const void* scales, const void* zero_points, - const DataFormat& src_format, size_t height, size_t width); + const DataFormat& dst_format, const void* src, const void* scales, const void* bias, const DataFormat& src_format, + size_t height, size_t width); /// Packs the quantized data and the quantization scale into a single buffer. 
/// diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp index 104a507a..7b235559 100644 --- a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -20,16 +20,13 @@ #include #include "kai/kai_common.h" -#include "test/common/MatMulMethod.hpp" #include "test/common/compare.hpp" #include "test/common/cpu_info.hpp" #include "test/common/data_format.hpp" #include "test/common/data_type.hpp" -#include "test/common/float16.hpp" #include "test/common/matmul_test_common.hpp" #include "test/common/matrix_portion.hpp" #include "test/common/printer.hpp" -#include "test/common/sme.hpp" #include "test/reference/fill.hpp" #include "test/reference/matmul.hpp" #include "test/reference/pack.hpp" @@ -41,7 +38,8 @@ namespace kai::test { /// List of supported matrix multiplication methods. -static const std::array matmul_methods = { +namespace { +const std::array matmul_methods = { MatMulMethod{ .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla", @@ -51,8 +49,6 @@ static const std::array matmul_methods = { .lhs_transposed = false, .rhs_transposed = false, - .is_sme2 = false, - .dst_format = DataFormat(DataType::FP32), .lhs_format = DataFormat(DataType::FP32), .packed_lhs_format = @@ -61,6 +57,7 @@ static const std::array matmul_methods = { .packed_rhs_format = DataFormat( DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4), .bias_format = DataFormat(DataType::FP32), + .fn_is_supported = cpu_has_bf16, .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, @@ -99,8 +96,6 @@ static const std::array matmul_methods = { .lhs_transposed = false, .rhs_transposed = false, - .is_sme2 = false, - .dst_format = DataFormat(DataType::FP32), .lhs_format = DataFormat(DataType::FP32), .packed_lhs_format = @@ -109,6 +104,7 @@ static 
const std::array matmul_methods = { .packed_rhs_format = DataFormat( DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4), .bias_format = DataFormat(DataType::UNKNOWN), + .fn_is_supported = cpu_has_bf16, .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, @@ -138,12 +134,13 @@ static const std::array matmul_methods = { .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, }}; +} /// Matrix multiplication test fixture. class MatMulTestBf16 : public testing::TestWithParam { private: /// Unique ID: m, n, k - using TestDataId = std::tuple; + using TestDataId = std::tuple; protected: /// Cached test data that is shared between multiple test case. @@ -203,7 +200,7 @@ protected: } std::vector packed_rhs; - packed_rhs.resize(method.packed_rhs_format.default_size_in_bytes(rhs_h, rhs_w)); + packed_rhs.resize(method.fn_get_packed_rhs_size(rhs_w, rhs_h)); if (has_rhs_pack) { const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(rhs_w); @@ -249,7 +246,7 @@ TEST_P(MatMulTestBf16, Output) { const auto& [method, info, portion] = GetParam(); const auto& data = test_data(); - if (method.is_sme2 && !cpu_has_sme2()) { + if (method.fn_is_supported && !method.fn_is_supported()) { GTEST_SKIP(); } @@ -372,5 +369,4 @@ INSTANTIATE_TEST_SUITE_P( MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle )), testing::PrintToStringParamName()); - } // namespace kai::test diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index 90f1f1df..752a3701 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -21,12 +21,10 @@ #include #include "kai/kai_common.h" -#include "test/common/MatMulMethod.hpp" #include "test/common/compare.hpp" #include "test/common/cpu_info.hpp" #include "test/common/data_format.hpp" #include "test/common/data_type.hpp" -#include 
"test/common/float16.hpp" #include "test/common/matmul_test_common.hpp" #include "test/common/matrix_portion.hpp" #include "test/common/printer.hpp" @@ -202,7 +200,7 @@ static const std::array matmul_methods = { class MatMulTest : public testing::TestWithParam { private: /// Unique ID: m, n, k, method_id. - using TestDataId = std::tuple; + using TestDataId = std::tuple; protected: /// Cached test data that is shared between multiple test case. -- GitLab From e383c89c4886c60ea522a84c23eed680d7b0fd72 Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Mon, 14 Oct 2024 22:45:06 +0100 Subject: [PATCH 05/10] Rename Lhs/Rhs packing functions in a generic manner Signed-off-by: Gunes Bayir --- CMakeLists.txt | 4 +- .../CMakeLists.txt | 4 +- .../matmul_clamp_f32_bf16p_bf16p.cpp | 18 ++--- kai/ukernels/matmul/BUILD.bazel | 16 +++-- ....c => kai_lhs_quant_pack_bf16p_f32_neon.c} | 52 ++++++-------- ....h => kai_lhs_quant_pack_bf16p_f32_neon.h} | 14 ++-- ...ai_rhs_quant_pack_bf16pbiasf32_f32_neon.c} | 35 ++++------ ...ai_rhs_quant_pack_bf16pbiasf32_f32_neon.h} | 26 ++++--- test/common/matmul_test_common.hpp | 10 +++ .../matmul_clamp_f32_bf16p_bf16p_test.cpp | 70 +++++++++---------- 10 files changed, 122 insertions(+), 127 deletions(-) rename kai/ukernels/matmul/pack/{kai_lhs_pack_f32p8x4_bf16_neon.c => kai_lhs_quant_pack_bf16p_f32_neon.c} (85%) rename kai/ukernels/matmul/pack/{kai_lhs_pack_f32p8x4_bf16_neon.h => kai_lhs_quant_pack_bf16p_f32_neon.h} (79%) rename kai/ukernels/matmul/pack/{kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c => kai_rhs_quant_pack_bf16pbiasf32_f32_neon.c} (93%) rename kai/ukernels/matmul/pack/{kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h => kai_rhs_quant_pack_bf16pbiasf32_f32_neon.h} (68%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 569b71b0..da64f9af 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -89,8 +89,8 @@ set(KLEIDIAI_FILES_NEON_FP16 ) set(KLEIDIAI_FILES_NEON_BF16 - kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.c - 
kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c + kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.c + kai/ukernels/matmul/pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.c kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c ) diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt index ff2fd03e..c59da140 100644 --- a/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt +++ b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt @@ -23,8 +23,8 @@ include_directories( add_executable(matmul_clamp_f32_bf16p_bf16p matmul_clamp_f32_bf16p_bf16p.cpp ${MATMUL_PATH}/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c - ${MATMUL_PACK_PATH}/kai_lhs_pack_f32p8x4_bf16_neon.c - ${MATMUL_PACK_PATH}/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c + ${MATMUL_PACK_PATH}/kai_lhs_quant_pack_bf16p_f32_neon.c + ${MATMUL_PACK_PATH}/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.c ) target_compile_options(matmul_clamp_f32_bf16p_bf16p diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp index 34b6db5e..60db9087 100644 --- a/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp +++ b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp @@ -28,10 +28,10 @@ // Include micro-kernel variants #include "kai/kai_common.h" -#include "kai_lhs_pack_f32p8x4_bf16_neon.h" +#include "kai_lhs_quant_pack_bf16p_f32_neon.h" #include "kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h" #include "kai_matmul_clamp_f32_bf16p_bf16p_interface.h" -#include "kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h" +#include "kai_rhs_quant_pack_bf16pbiasf32_f32_neon.h" inline float bf16_to_float(const bfloat16_t* v) { const uint16_t uint_rep = *reinterpret_cast(v); @@ -180,9 +180,9 @@ bool is_output_correct( int 
main() { // Parameters of the matrix multiplication. Change these values to see how the micro-kernels operate on different // sized matrices - const size_t M = 5; // Rows of LHS and DST matrices - const size_t N = 8; // Columns of RHS and DST matrices, and length of the Bias vector. - const size_t K = 7; // Columns of LHS, rows of RHS matrices + const size_t M = 25; // Rows of LHS and DST matrices + const size_t N = 28; // Columns of RHS and DST matrices, and length of the Bias vector. + const size_t K = 117; // Columns of LHS, rows of RHS matrices const size_t lhs_size = M * K; const size_t rhs_size = N * K; @@ -235,7 +235,7 @@ int main() { const size_t sr = ukernel.get_sr(); // In a single row, we pack nr bias values followed by K rows of nr RHS values - const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(N, K); + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_quant_pack_bf16pbiasf32_f32_neon(N, K, nr, kr); uint8_t* rhs_packed = new uint8_t[rhs_packed_size]; const size_t lhs_stride = K * sizeof(float); @@ -243,11 +243,11 @@ int main() { const size_t dst_stride_row = N * sizeof(float); const size_t dst_stride_col = sizeof(float); - const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_pack_f32p8x4_bf16_neon(M, K, mr, kr, sr); + const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon(M, K, mr, kr, sr); bfloat16_t* lhs_packed = new bfloat16_t[lhs_packed_size]; // Packing only needs to be performed once if the contents of the bias and RHS matrices are expected to be constant. 
- kai_run_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon( + kai_run_rhs_quant_pack_bf16pbiasf32_f32_neon( 1, N, K, nr, kr, sr, // Packing arguments rhs_stride, // RHS stride rhs, // RHS @@ -272,7 +272,7 @@ int main() { const auto timer_matmul_start = std::chrono::high_resolution_clock::now(); - kai_run_lhs_pack_f32p8x4_bf16_neon(M, K, mr, kr, sr, 0 /* m_idx_start */, lhs, lhs_stride, lhs_packed); + kai_run_lhs_quant_pack_bf16p_f32_neon(M, K, mr, kr, sr, 0 /* m_idx_start */, lhs, lhs_stride, lhs_packed); ukernel.run_matmul( M, N, K, // Dimensions diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 8fc817ab..8c7d11fa 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -177,9 +177,9 @@ kai_c_library( ) kai_c_library( - name = "lhs_pack_f32p8x4_bf16_neon", - srcs = ["pack/kai_lhs_pack_f32p8x4_bf16_neon.c"], - hdrs = ["pack/kai_lhs_pack_f32p8x4_bf16_neon.h"], + name = "lhs_quant_pack_bf16p_f32_neon", + srcs = ["pack/kai_lhs_quant_pack_bf16p_f32_neon.c"], + hdrs = ["pack/kai_lhs_quant_pack_bf16p_f32_neon.h"], cpu_uarch = kai_cpu_bf16(), ) @@ -191,9 +191,9 @@ kai_c_library( ) kai_c_library( - name = "rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon", - srcs = ["pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c"], - hdrs = ["pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h"], + name = "rhs_quant_pack_bf16pbiasf32_f32_neon", + srcs = ["pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.c"], + hdrs = ["pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.h"], cpu_uarch = kai_cpu_bf16(), ) @@ -359,18 +359,20 @@ kai_c_library( ":clamp_f32_qsi8d32p_qsi4c32p_i8mm", ":clamp_f32_qsi8d32p_qsi4c32p_interface", ":lhs_pack_f32p2vlx1_f32_sme", - ":lhs_pack_f32p8x4_bf16_neon", + ":lhs_quant_pack_bf16p_f32_neon", ":lhs_quant_pack_qai8dxp_f32", ":lhs_quant_pack_qsi8d32p_f32", ":rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon", ":rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", ":rhs_pack_kxn_f32pb_f32_f32_16vlx1_sme", 
":rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon", + "rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon", ":rhs_pack_kxn_f32pbiasf32_f32_f32_neon", ":rhs_pack_kxn_qsi4c32p_qsu4c32s1s0", ":rhs_pack_kxn_qsi4cxp_qs4cxs1s0", ":rhs_pack_nxk_qsi4c32p_qsu4c32s1s0", ":rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0", ":rhs_pack_nxk_qsi4cxp_qs4cxs1s0", + ":rhs_quant_pack_bf16pbiasf32_f32_neon", ], ) diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.c similarity index 85% rename from kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.c rename to kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.c index 32c845f7..cea886ed 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.c +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.c @@ -8,66 +8,54 @@ #error This file must be compiled for AArch64, FEAT_BF16. #else // Architectural features check. +#define MAX_MR 8 + #include #include #include #include "kai/kai_common.h" -static const size_t kai_mr = 8; -static const size_t kai_kr = 4; -static const size_t kai_sr = 1; - -size_t kai_get_m_step_lhs_pack_f32p8x4_bf16_neon(size_t mr) { - KAI_ASSUME(mr == kai_mr); - KAI_UNUSED(mr); - - return kai_mr; +size_t kai_get_m_step_lhs_quant_pack_bf16p_f32_neon(size_t mr) { + return mr; } -size_t kai_get_lhs_offset_lhs_pack_f32p8x4_bf16_neon(size_t m_idx, size_t lhs_stride) { - KAI_ASSUME(m_idx % (kai_mr) == 0); - +size_t kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon(size_t m_idx, size_t lhs_stride) { return m_idx * lhs_stride; } -size_t kai_get_lhs_packed_offset_lhs_pack_f32p8x4_bf16_neon(size_t m_idx, size_t k) { - KAI_ASSUME(m_idx % kai_mr == 0); +size_t kai_get_lhs_packed_offset_lhs_quant_pack_bf16p_f32_neon( + size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { + KAI_UNUSED(sr); + KAI_ASSUME(m_idx % mr == 0); - return m_idx * kai_roundup(k, kai_kr) * sizeof(bfloat16_t); + return m_idx * kai_roundup(k, kr) * 
sizeof(uint16_t); } -size_t kai_get_lhs_packed_size_lhs_pack_f32p8x4_bf16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { - KAI_ASSUME(mr == kai_mr); - KAI_ASSUME(kr == kai_kr); - KAI_ASSUME(sr == kai_sr); - - KAI_UNUSED(mr); - KAI_UNUSED(kr); +size_t kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { KAI_UNUSED(sr); - return kai_roundup(m, kai_mr) * kai_roundup(k, kai_kr) * sizeof(bfloat16_t); + return kai_roundup(m, mr) * kai_roundup(k, kr) * sizeof(uint16_t); } -void kai_run_lhs_pack_f32p8x4_bf16_neon( - size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, - void* lhs_packed) { - KAI_ASSUME(mr == kai_mr); - KAI_ASSUME(kr == kai_kr); - KAI_ASSUME(sr == kai_sr); +void kai_run_lhs_quant_pack_bf16p_f32_neon( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, size_t lhs_stride, + uint16_t* lhs_packed) { + KAI_UNUSED(sr); KAI_ASSUME(lhs != NULL); KAI_ASSUME(lhs_packed != NULL); KAI_ASSUME(m_idx_start == 0); + KAI_ASSUME(mr <= MAX_MR); - const size_t block_height = kai_mr; + const size_t block_height = mr; const size_t row_offset = 0; - const void* in[block_height]; + const void* in[MAX_MR]; for (size_t block_y = 0; block_y < m; block_y += block_height) { const size_t height = KAI_MIN(m - block_y, block_height); - void* out = (char*)lhs_packed + block_y * kai_roundup(k, kr) * sizeof(bfloat16_t); + void* out = (char*)lhs_packed + block_y * kai_roundup(k, kr) * sizeof(uint16_t); size_t width = k; for (size_t y = 0; y < height; y++) { diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.h b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h similarity index 79% rename from kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.h rename to kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h index d43838bc..e78652d7 100644 --- 
a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.h +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h @@ -21,7 +21,7 @@ extern "C" { /// @param[in] mr Number of rows to be interleaved. /// /// @return The m step value. -size_t kai_get_m_step_lhs_pack_f32p8x4_bf16_neon(size_t mr); +size_t kai_get_m_step_lhs_quant_pack_bf16p_f32_neon(size_t mr); /// Gets the offset in bytes to the data element in the LHS buffer. /// @@ -29,7 +29,7 @@ size_t kai_get_m_step_lhs_pack_f32p8x4_bf16_neon(size_t mr); /// @param[in] lhs_stride Row stride in bytes. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_offset_lhs_pack_f32p8x4_bf16_neon(size_t m_idx, size_t lhs_stride); +size_t kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon(size_t m_idx, size_t lhs_stride, size_t mr); /// Gets the offset in bytes to the data element in the packed LHS buffer. /// @@ -40,7 +40,7 @@ size_t kai_get_lhs_offset_lhs_pack_f32p8x4_bf16_neon(size_t m_idx, size_t lhs_st /// @param[in] sr Unused. Must be 1. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_packed_offset_lhs_pack_f32p8x4_bf16_neon(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_lhs_quant_pack_bf16p_f32_neon(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr); /// Gets the size in bytes of the packed LHS buffer. /// @@ -51,15 +51,15 @@ size_t kai_get_lhs_packed_offset_lhs_pack_f32p8x4_bf16_neon(size_t m_idx, size_t /// @param[in] sr Unused. Must be 1. /// /// @return The size in bytes of the packed LHS buffer. -size_t kai_get_lhs_packed_size_lhs_pack_f32p8x4_bf16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr); +size_t kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr); /// Runs the LHS packing function for matrix multiplication. 
/// /// The pointer of each buffers (LHS and packed LHS) needs to be added with offset /// calculated using the following functions: /// -/// * LHS: @ref kai_get_lhs_offset_lhs_pack_f32p8x4_bf16_neon. -/// * Packed LHS: @ref kai_get_lhs_packed_offset_lhs_pack_f32p8x4_bf16_neon. +/// * LHS: @ref kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon. +/// * Packed LHS: @ref kai_get_lhs_packed_offset_lhs_quant_pack_bf16p_f32_neon. /// /// @param[in] m Number of rows of the unpacked LHS matrix. /// @param[in] k Common dimension between the LHS and RHS matrix. @@ -70,7 +70,7 @@ size_t kai_get_lhs_packed_size_lhs_pack_f32p8x4_bf16_neon(size_t m, size_t k, si /// @param[in] lhs LHS matrix data buffer. /// @param[in] lhs_stride Row stride in bytes of the LHS matrix. /// @param[out] lhs_packed Packed RHS matrix. -void kai_run_lhs_pack_f32p8x4_bf16_neon( +void kai_run_lhs_quant_pack_bf16p_f32_neon( size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed); diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.c similarity index 93% rename from kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c rename to kai/ukernels/matmul/pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.c index 9e7c7dc9..8a512b21 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.c @@ -8,6 +8,8 @@ #error This file must be compiled for AArch64. #else // Architectural features check. 
+#define MAX_NR 12 + #include #include #include @@ -16,45 +18,39 @@ #include "kai/kai_common.h" -static const size_t kai_nr = 12; -static const size_t kai_kr = 4; - -size_t kai_get_n_step_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(void) { - return kai_nr; +size_t kai_get_n_step_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t nr) { + return nr; } -size_t kai_get_rhs_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n_idx) { - KAI_ASSUME(n_idx % kai_nr == 0); - +size_t kai_get_rhs_offset_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t n_idx) { return n_idx * sizeof(float); } -size_t kai_get_bias_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n_idx) { +size_t kai_get_bias_offset_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t n_idx) { return n_idx * sizeof(uint32_t); } -size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n_idx, size_t k) { - KAI_ASSUME(n_idx % kai_nr == 0); +size_t kai_get_rhs_packed_offset_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t n_idx, size_t k, size_t nr, size_t kr) { + KAI_ASSUME(n_idx % nr == 0); - return n_idx * (sizeof(uint32_t) + kai_roundup(k, kai_kr) * sizeof(bfloat16_t)); + return n_idx * (sizeof(uint32_t) + kai_roundup(k, kr) * sizeof(uint16_t)); } -size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n, size_t k) { - return kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(kai_roundup(n, kai_nr), k); +size_t kai_get_rhs_packed_size_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t n, size_t k, size_t nr, size_t kr) { + return kai_get_rhs_packed_offset_rhs_quant_pack_bf16pbiasf32_f32_neon(kai_roundup(n, nr), k, nr, kr); } -void kai_run_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon( +void kai_run_rhs_quant_pack_bf16pbiasf32_f32_neon( size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { KAI_ASSUME(num_groups == 1); - 
KAI_ASSUME(nr == kai_nr); - KAI_ASSUME(kr == kai_kr); KAI_ASSUME(sr == 1); KAI_ASSUME(rhs != NULL); KAI_ASSUME(scale == NULL); KAI_ASSUME(rhs_packed != NULL); KAI_ASSUME(extra_bytes == 0); KAI_ASSUME(params == NULL); + KAI_ASSUME(nr <= MAX_NR); size_t height = k; const size_t width = n; @@ -69,17 +65,16 @@ void kai_run_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon( // Fill zeros if bias is nullptr size_t bias_step = nr * sizeof(float); + uint8_t zero_bias[MAX_NR * sizeof(float)]; - void* zero_bias = NULL; if (bias == NULL) { - zero_bias = alloca(bias_step); memset(zero_bias, 0, bias_step); bias_step = 0; } const void* bias_ptr = bias == NULL ? zero_bias : bias; - size_t out_stride = kai_nr * kai_roundup(height, 4) * sizeof(bfloat16_t) + kai_nr * sizeof(uint32_t); + size_t out_stride = nr * kai_roundup(height, kr) * sizeof(uint16_t) + nr * sizeof(uint32_t); __asm__ __volatile__( "mov x22, %x[width]\n" diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.h similarity index 68% rename from kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h rename to kai/ukernels/matmul/pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.h index 7ef90846..e161caf9 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h +++ b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.h @@ -17,52 +17,56 @@ extern "C" { /// The starting row index must be divisible by `n_step`. /// /// @return The n step value. -size_t kai_get_n_step_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(void); +size_t kai_get_n_step_rhs_quant_pack_bf16pbiasf32_f32_neon(void); /// Gets the offset in bytes to the data element in the RHS matrix buffer. /// /// @param[in] n_idx Column index. /// /// @return The offset in bytes to the data element. 
-size_t kai_get_rhs_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n_idx); +size_t kai_get_rhs_offset_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t n_idx); /// Gets the offset in bytes to the data element in the bias buffer. /// /// @param[in] n_idx Column index. /// /// @return The offset in bytes to the data element. -size_t kai_get_bias_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n_idx); +size_t kai_get_bias_offset_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t n_idx); /// Gets the offset in bytes to the data element in the packed RHS buffer. /// /// @param[in] n_idx Row index. /// @param[in] k Number of columns. +/// @param[in] nr Block size in N dimension. +/// @param[in] kr Block size in K dimension. /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t n_idx, size_t k, size_t nr, size_t kr); /// Gets the size in bytes of the packed RHS buffer. /// /// @param[in] n Number of rows. /// @param[in] k Number of columns. +/// @param[in] nr Block size in N dimension. +/// @param[in] kr Block size in K dimension. /// /// @return The size in bytes of the packed RHS buffer. -size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n, size_t k); +size_t kai_get_rhs_packed_size_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t n, size_t k, size_t nr, size_t kr); /// Runs the RHS packing function for matrix multiplication. /// /// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset /// calculated using the following functions: /// -/// * RHS: @ref kai_get_rhs_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon. -/// * Bias: @ref kai_get_bias_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon. -/// * Output: @ref kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon. 
+/// * RHS: @ref kai_get_rhs_offset_rhs_quant_pack_bf16pbiasf32_f32_neon. +/// * Bias: @ref kai_get_bias_offset_rhs_quant_pack_bf16pbiasf32_f32_neon. +/// * Output: @ref kai_get_rhs_packed_offset_rhs_quant_pack_bf16pbiasf32_f32_neon. /// /// @param[in] num_groups Number of groups. It must be 1. /// @param[in] n Number of columns of the output matrix. /// @param[in] k Common dimension between the LHS and RHS matrix. -/// @param[in] nr Block size in N dimension. It must be 12. -/// @param[in] kr Block size in K dimension. It must be 4. +/// @param[in] nr Block size in N dimension. +/// @param[in] kr Block size in K dimension. /// @param[in] sr Number of kr splits. It must be 1. /// @param[in] rhs_stride Row stride in bytes of the RHS matrix. /// @param[in] rhs RHS matrix data buffer. @@ -71,7 +75,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t /// @param[out] rhs_packed Packed RHS matrix. /// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. /// @param[in] params Extra packing parameters. It must be NULL. -void kai_run_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon( +void kai_run_rhs_quant_pack_bf16pbiasf32_f32_neon( size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params); diff --git a/test/common/matmul_test_common.hpp b/test/common/matmul_test_common.hpp index 90eb95d9..c42096d0 100644 --- a/test/common/matmul_test_common.hpp +++ b/test/common/matmul_test_common.hpp @@ -160,6 +160,16 @@ struct MatMulMethod { /// @return The size in bytes. std::function fn_get_packed_rhs_size; + /// Gets the size in bytes of the packed RHS matrix. + /// + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] nr Block size in N dimension. 
+ /// @param[in] kr Block size in K dimension. + /// + /// @return The size in bytes. + std::function fn_get_packed_rhs_size_generic_block_size; + /// Gets the offset in bytes of the packed RHS matrix in the RHS packing kernel /// /// @param[in] n_idx Coordinate of the matrix in N dimension. diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp index 7b235559..cd1b929d 100644 --- a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -33,8 +33,8 @@ // matmul_clamp_f32_bf16p_bf16p #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h" -#include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.h" -#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h" +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h" +#include "kai/ukernels/matmul/pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.h" namespace kai::test { /// List of supported matrix multiplication methods. 
@@ -65,22 +65,23 @@ const std::array matmul_methods = { .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_bf16pbiasf32_f32_neon, .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p8x4_bf16_neon, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p8x4_bf16_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_pack_f32p8x4_bf16_neon, - .fn_pack_lhs = kai_run_lhs_pack_f32p8x4_bf16_neon, + .fn_get_lhs_offset = nullptr, + .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon, + .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p_f32_neon, - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, - .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_bf16pbiasf32_f32_neon, + .fn_get_packed_rhs_size = nullptr, + .fn_get_packed_rhs_size_generic_block_size = kai_get_rhs_packed_size_rhs_quant_pack_bf16pbiasf32_f32_neon, + .fn_get_pack_rhs_packed_rhs_offset = nullptr, .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_pack_rhs = kai_run_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_pack_rhs = kai_run_rhs_quant_pack_bf16pbiasf32_f32_neon, - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_get_bias_offset = 
kai_get_bias_offset_rhs_quant_pack_bf16pbiasf32_f32_neon, .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, @@ -112,22 +113,23 @@ const std::array matmul_methods = { .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_bf16pbiasf32_f32_neon, .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p8x4_bf16_neon, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p8x4_bf16_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_pack_f32p8x4_bf16_neon, - .fn_pack_lhs = kai_run_lhs_pack_f32p8x4_bf16_neon, + .fn_get_lhs_offset = nullptr, + .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon, + .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p_f32_neon, - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, - .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_bf16pbiasf32_f32_neon, + .fn_get_packed_rhs_size = nullptr, + .fn_get_packed_rhs_size_generic_block_size = kai_get_rhs_packed_size_rhs_quant_pack_bf16pbiasf32_f32_neon, + .fn_get_pack_rhs_packed_rhs_offset = nullptr, .fn_get_main_packed_rhs_offset = 
kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_pack_rhs = kai_run_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_pack_rhs = kai_run_rhs_quant_pack_bf16pbiasf32_f32_neon, - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_bf16pbiasf32_f32_neon, .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, @@ -199,8 +201,11 @@ protected: bias = fill_matrix_random(bias_h, bias_w, method.bias_format, 3); } + constexpr size_t nr = 12; + constexpr size_t kr = 4; + std::vector packed_rhs; - packed_rhs.resize(method.fn_get_packed_rhs_size(rhs_w, rhs_h)); + packed_rhs.resize(method.fn_get_packed_rhs_size_generic_block_size(rhs_w, rhs_h, nr, kr)); if (has_rhs_pack) { const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(rhs_w); @@ -278,23 +283,14 @@ TEST_P(MatMulTestBf16, Output) { const uint8_t* lhs_data = nullptr; uintptr_t lhs_offset = 0; - if (method.is_pack_lhs_needed()) { - lhs_data = data.ref_packed_lhs.data(); - - const auto ref_lhs_offset = - method.packed_lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, info.k); - KAI_UNUSED(ref_lhs_offset); + lhs_data = data.ref_packed_lhs.data(); - lhs_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k); + const auto ref_lhs_offset = method.packed_lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, info.k); + KAI_UNUSED(ref_lhs_offset); - // TODO: Check with ref_lhs_offset after fixing default_offset_in_bytes() - } else { - lhs_data = data.lhs.data(); + lhs_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k); - lhs_offset = method.fn_get_lhs_offset(lhs_start_row, lhs_stride); - const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, lhs_w); - 
ASSERT_EQ(lhs_offset, ref_lhs_offset); - } + // TODO: Check with ref_lhs_offset after fixing default_offset_in_bytes() const auto rhs_stride = method.rhs_format.default_row_stride(rhs_w); -- GitLab From 19a89ffcb1a7fec5f86e38cb2c5480bdab1f0a53 Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Tue, 15 Oct 2024 00:29:52 +0100 Subject: [PATCH 06/10] Improve end-to-end bf16 matmul tests Signed-off-by: Gunes Bayir --- .../pack/kai_lhs_quant_pack_bf16p_f32_neon.h | 2 +- test/common/matmul_test_common.hpp | 7 +- .../matmul_clamp_f32_bf16p_bf16p_test.cpp | 96 ++++++++++--------- 3 files changed, 58 insertions(+), 47 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h index e78652d7..3b002cfd 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h @@ -29,7 +29,7 @@ size_t kai_get_m_step_lhs_quant_pack_bf16p_f32_neon(size_t mr); /// @param[in] lhs_stride Row stride in bytes. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon(size_t m_idx, size_t lhs_stride, size_t mr); +size_t kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon(size_t m_idx, size_t lhs_stride); /// Gets the offset in bytes to the data element in the packed LHS buffer. /// diff --git a/test/common/matmul_test_common.hpp b/test/common/matmul_test_common.hpp index c42096d0..d7151abf 100644 --- a/test/common/matmul_test_common.hpp +++ b/test/common/matmul_test_common.hpp @@ -30,8 +30,9 @@ struct MatMulShape { struct MatMulMethod { std::string_view name; ///< Name of matmul method. - size_t m0; ///< Block size in M dimension. - size_t n0; ///< Block size in N dimension. + size_t m0{0}; ///< Block size in M dimension. + size_t n0{0}; ///< Block size in N dimension. + size_t k0{0}; ///< Block size in K dimension. bool lhs_transposed; ///< LHS matrix is transposed. 
bool rhs_transposed; ///< RHS matrix is transposed. @@ -168,7 +169,7 @@ struct MatMulMethod { /// @param[in] kr Block size in K dimension. /// /// @return The size in bytes. - std::function fn_get_packed_rhs_size_generic_block_size; + std::function fn_get_packed_rhs_size_generic_block_size = nullptr; /// Gets the offset in bytes of the packed RHS matrix in the RHS packing kernel /// diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp index cd1b929d..a68fc7aa 100644 --- a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -45,6 +45,7 @@ const std::array matmul_methods = { .m0 = 8, .n0 = 12, + .k0 = 4, .lhs_transposed = false, .rhs_transposed = false, @@ -68,7 +69,7 @@ const std::array matmul_methods = { .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_bf16pbiasf32_f32_neon, .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_get_lhs_offset = nullptr, + .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon, .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon, .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p_f32_neon, @@ -93,6 +94,7 @@ const std::array matmul_methods = { .m0 = 8, .n0 = 12, + .k0 = 4, .lhs_transposed = false, .rhs_transposed = false, @@ -116,7 +118,7 @@ const std::array matmul_methods = { .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_bf16pbiasf32_f32_neon, .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_get_lhs_offset = nullptr, + .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon, .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon, .fn_get_packed_lhs_offset = 
kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p_f32_neon, @@ -209,7 +211,9 @@ protected: if (has_rhs_pack) { const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(rhs_w); - method.pack_rhs(info.n, info.k, rhs.data(), ref_rhs_row_stride, bias.data(), nullptr, packed_rhs.data()); + method.pack_rhs( + info.n, info.k, rhs.data(), ref_rhs_row_stride, has_bias ? bias.data() : nullptr, nullptr, + packed_rhs.data()); } KAI_ASSUME(method.lhs_format.is_raw()); @@ -217,10 +221,10 @@ protected: KAI_ASSUME(method.dst_format.is_raw()); auto ref_dst = matmul( - lhs.data(), nullptr, nullptr, method.lhs_format.data_type(), // - rhs.data(), rhs_scales.data(), nullptr, method.rhs_format.data_type(), // - bias.data(), nullptr, nullptr, method.bias_format.data_type(), // - method.dst_format.data_type(), // + lhs.data(), nullptr, nullptr, method.lhs_format.data_type(), // + rhs.data(), rhs_scales.data(), nullptr, method.rhs_format.data_type(), // + has_bias ? bias.data() : nullptr, nullptr, nullptr, method.bias_format.data_type(), // + method.dst_format.data_type(), // info.m, info.n, info.k, method.lhs_transposed, method.rhs_transposed); const auto& data = _data[data_id] = { @@ -271,54 +275,61 @@ TEST_P(MatMulTestBf16, Output) { GTEST_SKIP(); } - const auto lhs_w = method.lhs_transposed ? info.m : info.k; - const auto rhs_w = method.rhs_transposed ? info.k : info.n; - const auto bias_w = info.n; - const auto dst_w = info.n; + // ASSERT_FALSE(method.lhs_transposed()); + + const size_t lhs_w = info.k; + const size_t rhs_w = rect.width(); + const size_t bias_w = info.n; + const size_t dst_w = info.n; + const bool has_bias = (data.bias.size() > 0); const auto lhs_start_row = method.lhs_transposed ? 0 : rect.start_row(); - const auto lhs_start_col = method.lhs_transposed ? 
rect.start_row() : 0; const auto lhs_stride = method.lhs_format.default_row_stride(lhs_w); - const uint8_t* lhs_data = nullptr; - uintptr_t lhs_offset = 0; - - lhs_data = data.ref_packed_lhs.data(); + std::vector lhs_data; + const size_t lhs_packed_size = method.fn_get_packed_lhs_size(info.m, info.k, method.m0, method.k0, 1 /* sr */); + lhs_data.resize(lhs_packed_size); - const auto ref_lhs_offset = method.packed_lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, info.k); - KAI_UNUSED(ref_lhs_offset); + uintptr_t lhs_offset = method.fn_get_lhs_offset(lhs_start_row, lhs_stride); + uintptr_t lhs_packed_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k); - lhs_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k); + KAI_UNUSED(lhs_offset); + method.fn_pack_lhs( + rect.height(), info.k, method.m0, method.k0, 1 /* sr */, 0 /* m_idx_start */, data.lhs.data() + lhs_offset, + lhs_stride, lhs_data.data() + lhs_packed_offset); - // TODO: Check with ref_lhs_offset after fixing default_offset_in_bytes() + const auto rhs_stride = method.rhs_format.default_row_stride(info.n); - const auto rhs_stride = method.rhs_format.default_row_stride(rhs_w); + std::vector rhs_data; + const size_t rhs_packed_size = + method.fn_get_packed_rhs_size_generic_block_size(info.n, info.k, method.n0, method.k0); + rhs_data.resize(rhs_packed_size); - const uint8_t* rhs_data = nullptr; - uintptr_t rhs_offset = 0; + const auto packed_rhs_start_row = rect.start_col(); + const auto packed_rhs_start_col = 0; - if (method.is_pack_rhs_needed()) { - const auto packed_rhs_start_row = rect.start_col(); - const auto packed_rhs_start_col = 0; + uintptr_t rhs_offset = method.fn_get_rhs_offset(rect.start_col()); + uintptr_t rhs_packed_offset = method.fn_get_main_packed_rhs_offset(packed_rhs_start_row, info.k); + const auto ref_rhs_packed_offset = + method.packed_rhs_format.default_offset_in_bytes(packed_rhs_start_row, packed_rhs_start_col, info.k); - rhs_data = 
data.ref_packed_rhs.data(); + ASSERT_EQ(rhs_packed_offset, ref_rhs_packed_offset); - rhs_offset = method.fn_get_main_packed_rhs_offset(packed_rhs_start_row, info.k); - const auto ref_rhs_offset = - method.packed_rhs_format.default_offset_in_bytes(packed_rhs_start_row, packed_rhs_start_col, info.k); + uintptr_t bias_offset = sizeof(float) * rect.start_col(); - ASSERT_EQ(rhs_offset, ref_rhs_offset); - } else { - const auto rhs_start_row = method.rhs_transposed ? rect.start_col() : 0; - const auto rhs_start_col = method.rhs_transposed ? 0 : rect.start_col(); + method.fn_pack_rhs( + 1, // num_groups + rhs_w, info.k, method.n0, method.k0, + 1, // sr + rhs_stride, data.rhs.data() + rhs_offset, has_bias ? data.bias.data() + bias_offset : nullptr, + NULL, // Scale + rhs_data.data() + rhs_packed_offset, 0, NULL); - rhs_data = data.rhs.data(); - rhs_offset = method.rhs_format.default_offset_in_bytes(rhs_start_row, rhs_start_col, rhs_w); + if (has_bias) { + const auto ref_bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_col(), bias_w); + ASSERT_EQ(ref_bias_offset, bias_offset); } - const auto* bias_data = data.bias.data(); - const auto bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_row(), bias_w); - const auto dst_stride = method.dst_format.default_row_stride(dst_w); const auto dst_offset = method.fn_get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); const auto ref_dst_offset = method.dst_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), dst_w); @@ -330,13 +341,12 @@ TEST_P(MatMulTestBf16, Output) { std::vector dst; dst.resize(dst_size); - method.main_kernel( - rect.height(), rect.width(), info.k, lhs_data + lhs_offset, rhs_data + rhs_offset, bias_data + bias_offset, - dst.data() + dst_offset, lhs_stride, rhs_stride, dst_stride, -std::numeric_limits::infinity(), + rect.height(), rect.width(), info.k, lhs_data.data() + lhs_packed_offset, rhs_data.data() + rhs_packed_offset, + NULL, dst.data() + 
dst_offset, lhs_stride, rhs_stride, dst_stride, -std::numeric_limits::infinity(), std::numeric_limits::infinity()); - DefaultMismatchHandler handler(0, 0.1, 0, 0.05); + DefaultMismatchHandler handler(0, 0.02, 0, 0.05); const auto success = compare(dst.data(), data.ref_dst.data(), method.dst_format, info.m, info.n, rect, handler); ASSERT_TRUE(success); -- GitLab From f8a505b181dd0c1633de0d84c21031c5cf4a4583 Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Thu, 17 Oct 2024 12:13:23 +0100 Subject: [PATCH 07/10] Modifications on names, function argument, and types Following modifications have been made in this commit: - Rhs packing function and matmul has been renamed - Unnecessary function arguments in several functions have been dropped - Pointer types are converted to explicit types instead of using void * if the type is uniform in the tensor - Added recognition of bf16 data type and the kernel in the Readme of the project Signed-off-by: Gunes Bayir --- CMakeLists.txt | 4 +- README.md | 15 ++++ .../CMakeLists.txt | 4 +- .../matmul_clamp_f32_bf16p_bf16p.cpp | 71 +++++++-------- kai/ukernels/matmul/BUILD.bazel | 12 +-- ...bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c} | 40 ++++----- ...bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h} | 37 ++++---- ...i_matmul_clamp_f32_bf16p_bf16p_interface.h | 2 +- .../pack/kai_lhs_quant_pack_bf16p_f32_neon.h | 2 +- ...hs_quant_pack_kxn_bf16pbiasf32_f32_neon.c} | 38 ++++---- ...hs_quant_pack_kxn_bf16pbiasf32_f32_neon.h} | 26 +++--- test/common/matmul_test_common.hpp | 12 +-- .../matmul_clamp_f32_bf16p_bf16p_test.cpp | 87 +++++++++++-------- 13 files changed, 178 insertions(+), 172 deletions(-) rename kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/{kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c => kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c} (95%) rename kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/{kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h => 
kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h} (76%) rename kai/ukernels/matmul/pack/{kai_rhs_quant_pack_bf16pbiasf32_f32_neon.c => kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c} (92%) rename kai/ukernels/matmul/pack/{kai_rhs_quant_pack_bf16pbiasf32_f32_neon.h => kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h} (61%) diff --git a/CMakeLists.txt b/CMakeLists.txt index da64f9af..fda7a92c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -90,8 +90,8 @@ set(KLEIDIAI_FILES_NEON_FP16 set(KLEIDIAI_FILES_NEON_BF16 kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.c - kai/ukernels/matmul/pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.c - kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c + kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c + kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c ) set(KLEIDIAI_FILES_NEON diff --git a/README.md b/README.md index a139d11b..2d963d48 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,7 @@ Some of the data types currently supported with the KleidiAI library are the fol |---------------------------------------------------------------------------------------------------------------------| ----------- | ----------- | | Floating-point 32-bit | f32 | | | Floating-point 16-bit | f16 | | +| Brain Floating-point 16-bit | bf16 | | | Quantized (q) Symmetric (s) Signed (i) 4-bit (4) Per-Channel (cx) quantization parameters | qsi4cx | An fp32 multiplier shared among all values of the same channel. `x` denotes the entirety of the channel | | Quantized (q) Asymmetric (a) Signed (i) 8-bit (8) Per-Dimension (dx) (for example, Per-Row) quantization parameters | qai8dx | An fp32 multiplier and a int32 zero offset shared among all values of the same dimension. | @@ -177,6 +178,20 @@ Some of the data types currently supported with the KleidiAI library are the fol
+ + Matrix-multiplication with LHS packed and RHS packed matrices + matmul_clamp_f32_bf16p_bf16p + + LHS: bf16p
+ RHS: bf16p
+ DST: f32
+ + + + + The packing function for the RHS and Lhs matrices is listed in the header file of the GEMM micro kernel.
+ +

How to build

diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt index c59da140..e1fafe23 100644 --- a/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt +++ b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt @@ -22,9 +22,9 @@ include_directories( # Files requires to build the executable add_executable(matmul_clamp_f32_bf16p_bf16p matmul_clamp_f32_bf16p_bf16p.cpp - ${MATMUL_PATH}/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c + ${MATMUL_PATH}/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c ${MATMUL_PACK_PATH}/kai_lhs_quant_pack_bf16p_f32_neon.c - ${MATMUL_PACK_PATH}/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.c + ${MATMUL_PACK_PATH}/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c ) target_compile_options(matmul_clamp_f32_bf16p_bf16p diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp index 60db9087..d6463410 100644 --- a/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp +++ b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp @@ -29,40 +29,40 @@ // Include micro-kernel variants #include "kai/kai_common.h" #include "kai_lhs_quant_pack_bf16p_f32_neon.h" -#include "kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h" +#include "kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h" #include "kai_matmul_clamp_f32_bf16p_bf16p_interface.h" -#include "kai_rhs_quant_pack_bf16pbiasf32_f32_neon.h" +#include "kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h" -inline float bf16_to_float(const bfloat16_t* v) { - const uint16_t uint_rep = *reinterpret_cast(v); +inline static float bf16_to_float(const uint16_t* v) { + const uint16_t uint_rep = *v; return kai_cast_f32_bf16(uint_rep); } namespace { /// Micro-kernel interface constexpr kai_matmul_clamp_f32_bf16p_bf16p_ukernel ukernel{ - 
kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - kai_run_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla}; + kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + kai_run_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla}; /// @brief Truncate the 32-bit floating point number's least significant 16 mantissa bits /// @param x floating-point number /// @return truncated floating-point number -float truncate(float x) { +inline static float truncate(float x) { uint32_t uval = (*reinterpret_cast(&x) & 0xffff0000); return *reinterpret_cast(&uval); } /// Reference 
implementation of matrix multiplication -void run_matmul_ref( +static void run_matmul_ref( size_t m, size_t n, size_t k, const float* lhs, const float* rhs, const float* bias, float* dst, float scalar_min, float scalar_max) { for (size_t row_idx = 0; row_idx < m; ++row_idx) { @@ -88,15 +88,6 @@ void fill_matrix(size_t num_rows, size_t num_cols, float* dst, const float weigh } } -void fill_identity(size_t num_rows, size_t num_cols, float* dst, const float weight) { - for (size_t i = 0; i < num_rows * num_cols; i++) { - int col = i % num_cols; - int row = i / num_cols; - - dst[i] = (col == row ? 1.f : 0.f); - } -} - /// Print the matrix void print_matrix(size_t num_rows, size_t num_cols, const char* name, const float* src) { std::cout << name << " = [\n"; @@ -110,7 +101,7 @@ void print_matrix(size_t num_rows, size_t num_cols, const char* name, const floa std::cout << ("]\n\n"); } -void print_matrix(size_t num_rows, size_t num_cols, const char* name, const bfloat16_t* src) { +void print_matrix(size_t num_rows, size_t num_cols, const char* name, const uint16_t* src) { std::cout << name << " = [\n"; for (size_t y = 0; y < num_rows; ++y) { std::cout << " ["; @@ -131,8 +122,8 @@ void print_mixed_prec_matrix( for (size_t x = 0; x < num_cols; ++x) { if (x >= nr) { // print bfloat - const bfloat16_t* src_elm = - reinterpret_cast(src_row + nr * sizeof(float) + (x - nr) * sizeof(bfloat16_t)); + const uint16_t* src_elm = + reinterpret_cast(src_row + nr * sizeof(float) + (x - nr) * sizeof(uint16_t)); std::cout << std::setprecision(2) << std::fixed << bf16_to_float(src_elm) << ", "; } else { // print float @@ -235,7 +226,7 @@ int main() { const size_t sr = ukernel.get_sr(); // In a single row, we pack nr bias values followed by K rows of nr RHS values - const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_quant_pack_bf16pbiasf32_f32_neon(N, K, nr, kr); + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(N, K, nr, kr); uint8_t* 
rhs_packed = new uint8_t[rhs_packed_size]; const size_t lhs_stride = K * sizeof(float); @@ -244,17 +235,15 @@ int main() { const size_t dst_stride_col = sizeof(float); const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon(M, K, mr, kr, sr); - bfloat16_t* lhs_packed = new bfloat16_t[lhs_packed_size]; + uint16_t* lhs_packed = new uint16_t[lhs_packed_size]; // Packing only needs to be performed once if the contents of the bias and RHS matrices are expected to be constant. - kai_run_rhs_quant_pack_bf16pbiasf32_f32_neon( - 1, N, K, nr, kr, sr, // Packing arguments - rhs_stride, // RHS stride - rhs, // RHS - bias, // Bias - NULL, // Scale - rhs_packed, // RHS packed - 0, NULL); + kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon( + N, K, nr, kr, sr, // Packing arguments + rhs_stride, // RHS stride + rhs, // RHS + bias, // Bias + rhs_packed); // RHS packed // The RHS and Bias buffers can be freed after packing, however we reuse them for the reference test below @@ -262,7 +251,7 @@ int main() { const size_t rhs_packed_cols = nr + kai_roundup(K, kr) * nr; // Each col has nr floats and then K*nr bfloats - int rhs_packed_stride = nr * sizeof(float) + kai_roundup(K, kr) * nr * sizeof(bfloat16_t); + int rhs_packed_stride = nr * sizeof(float) + kai_roundup(K, kr) * nr * sizeof(uint16_t); const size_t rhs_packed_rows = rhs_packed_size / rhs_packed_stride; print_mixed_prec_matrix(rhs_packed_rows, rhs_packed_cols, "rhs_packed", rhs_packed, nr, rhs_packed_stride); @@ -301,7 +290,7 @@ int main() { const bool is_valid = is_output_correct(M, N, rel_tolerance, dst_ref, dst); std::cout << "TEST[matmul_clamp_f32_bf16p_bf16p]\n"; - std::cout << "- ukernel: matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla\n"; + std::cout << "- ukernel: matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla\n"; if (is_valid) { std::cout << "- Status: PASSED\n"; std::cout << "- Performance: " << time_matmul.count() << "ns\n"; diff --git a/kai/ukernels/matmul/BUILD.bazel 
b/kai/ukernels/matmul/BUILD.bazel index 8c7d11fa..5a0d0121 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -41,8 +41,8 @@ kai_c_library( kai_c_library( name = "clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla", - srcs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c"], - hdrs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h"], + srcs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c"], + hdrs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h"], cpu_uarch = kai_cpu_bf16(), deps = [ ":clamp_f32_bf16p_bf16p_interface", @@ -191,9 +191,9 @@ kai_c_library( ) kai_c_library( - name = "rhs_quant_pack_bf16pbiasf32_f32_neon", - srcs = ["pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.c"], - hdrs = ["pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.h"], + name = "rhs_quant_pack_kxn_bf16pbiasf32_f32_neon", + srcs = ["pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c"], + hdrs = ["pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h"], cpu_uarch = kai_cpu_bf16(), ) @@ -373,6 +373,6 @@ kai_c_library( ":rhs_pack_nxk_qsi4c32p_qsu4c32s1s0", ":rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0", ":rhs_pack_nxk_qsi4cxp_qs4cxs1s0", - ":rhs_quant_pack_bf16pbiasf32_f32_neon", + ":rhs_quant_pack_kxn_bf16pbiasf32_f32_neon", ], ) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c similarity index 95% rename from kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c rename to kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c index 7132d15b..40495f9d 100644 --- 
a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c @@ -8,58 +8,56 @@ #error This file must be compiled for AArch64, FEAT_BF16. #else // Architectural features check. +#include "kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h" + #include #include #include -#include - -typedef bfloat16_t bfloat16; #include "kai/kai_common.h" -#include "kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h" static const size_t kai_mr = 8; static const size_t kai_nr = 12; static const size_t kai_kr = 4; static const size_t kai_sr = 1; -size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) { +size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void) { return kai_mr; } -size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) { +size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void) { return kai_nr; } -size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) { +size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void) { return kai_mr; } -size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) { +size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void) { return kai_nr; } -size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) { +size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) { +size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t m_idx, size_t k) { +size_t 
kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(size_t m_idx, size_t k) { KAI_ASSUME(m_idx % kai_mr == 0); - return m_idx * kai_roundup(k, kai_kr) * sizeof(bfloat16_t); + return m_idx * kai_roundup(k, kai_kr) * sizeof(uint16_t); } -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(size_t n_idx, size_t k) { KAI_ASSUME(n_idx % kai_nr == 0); - return n_idx * (sizeof(float) + kai_roundup(k, kai_kr) * sizeof(bfloat16_t)); + return n_idx * (sizeof(float) + kai_roundup(k, kai_kr) * sizeof(uint16_t)); } -size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla( +size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla( size_t m_idx, size_t n_idx, size_t stride) { KAI_ASSUME(m_idx % kai_mr == 0); KAI_ASSUME(n_idx % kai_nr == 0); @@ -67,15 +65,15 @@ size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mm return m_idx * stride + n_idx * sizeof(float); } -size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(size_t m, size_t n) { return m * n * sizeof(float); } -void kai_run_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla( - size_t m, size_t n, size_t k, // - const void* lhs_packed, // - const void* rhs_packed, // - void* dst, size_t dst_stride_row, size_t dst_stride_col, // +void kai_run_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla( + size_t m, size_t n, size_t k, // + const uint16_t* lhs_packed, // + const void* rhs_packed, // + float* dst, size_t dst_stride_row, size_t dst_stride_col, // float clamp_min, float clamp_max) { KAI_ASSERT(dst_stride_col == sizeof(float)); diff --git 
a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h similarity index 76% rename from kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h rename to kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h index 316933ed..03c8bb52 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h @@ -11,6 +11,7 @@ #else // Architectural features check. #include +#include #ifdef __cplusplus extern "C" { @@ -23,42 +24,42 @@ extern "C" { /// The starting row index must be divisible by `m_step`. /// /// @return The m step value. -size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void); +size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void); /// Gets n step value. /// /// The starting column index must be divisible by `n_step`. /// /// @return The n step value. -size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void); +size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void); /// Gets mr value. /// /// This is the packing parameter which must be used to pack the LHS matrix. /// /// @return The mr value. -size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void); +size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void); /// Gets nr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The nr value. 
-size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void); +size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void); /// Gets kr value. /// /// This is the packing parameter which must be used to pack the LHS & RHS matrices. /// /// @return The kr value. -size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void); +size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void); /// Gets sr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The sr value. -size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void); +size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void); /// Gets the offset in bytes to the data element in the packed LHS matrix buffer. /// @@ -66,7 +67,7 @@ size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) /// @param[in] k Number of columns. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(size_t m_idx, size_t k); /// Gets the offset in bytes to the data element in the packed RHS matrix buffer. /// @@ -74,7 +75,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_ /// @param[in] k Number of columns. /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(size_t n_idx, size_t k); /// Gets the offset in bytes to the data element in the destination matrix buffer. 
/// @@ -83,7 +84,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_ /// @param[in] stride Row stride in bytes. /// /// @return The offset in bytes to the data element. -size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla( +size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla( size_t m_idx, size_t n_idx, size_t stride); /// Gets the size in bytes of the destination matrix buffer. @@ -92,16 +93,16 @@ size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mm /// @param[in] n Number of columns. /// /// @return The size in bytes of the destination matrix buffer. -size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t m, size_t n); +size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(size_t m, size_t n); /// Runs the matrix multiplication microkernel followed by a clamp operation. /// /// The pointer of each buffers (packed LHS, packed RHS and output) needs to be added with offset /// calculated using the following functions: /// -/// * Packed LHS: @ref kai_get_lhs_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla. -/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla. -/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla. +/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla. /// /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. 
@@ -113,11 +114,11 @@ size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) /// @param[in] clamp_min Minimum value to clamp the final result. /// @param[in] clamp_max Maximum value to clamp the final result. -void kai_run_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla( - size_t m, size_t n, size_t k, // - const void* lhs_packed, // - const void* rhs_packed, // - void* dst, size_t dst_stride_row, size_t dst_stride_col, // +void kai_run_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla( + size_t m, size_t n, size_t k, // + const uint16_t* lhs_packed, // + const void* rhs_packed, // + float* dst, size_t dst_stride_row, size_t dst_stride_col, // float clamp_min, float clamp_max); #ifdef __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h index 768f63f5..37d5fcc9 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h @@ -32,7 +32,7 @@ typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_size_func_t)(size_t m, /// Micro-kernel core function ("run" method) typedef void (*kai_matmul_clamp_f32_bf16p_bf16p_run_matmul_func_t)( - size_t m, size_t n, size_t k, const void* lhs, const void* rhs_packed, void* dst, size_t dst_stride_row, + size_t m, size_t n, size_t k, const uint16_t* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max); /// Micro-kernel interface diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h index 3b002cfd..200a66c4 100644 
--- a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h @@ -69,7 +69,7 @@ size_t kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon(size_t m, size_t k, /// @param[in] m_idx_start Unused. Must be 0. /// @param[in] lhs LHS matrix data buffer. /// @param[in] lhs_stride Row stride in bytes of the LHS matrix. -/// @param[out] lhs_packed Packed RHS matrix. +/// @param[out] lhs_packed Packed LHS matrix. void kai_run_lhs_quant_pack_bf16p_f32_neon( size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed); diff --git a/kai/ukernels/matmul/pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c similarity index 92% rename from kai/ukernels/matmul/pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.c rename to kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c index 8a512b21..6f778e2a 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c @@ -10,7 +10,6 @@ #define MAX_NR 12 -#include #include #include #include @@ -18,63 +17,56 @@ #include "kai/kai_common.h" -size_t kai_get_n_step_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t nr) { +size_t kai_get_n_step_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t nr) { return nr; } -size_t kai_get_rhs_offset_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t n_idx) { +size_t kai_get_rhs_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n_idx) { return n_idx * sizeof(float); } -size_t kai_get_bias_offset_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t n_idx) { +size_t kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n_idx) { return n_idx * sizeof(uint32_t); } -size_t kai_get_rhs_packed_offset_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t n_idx, size_t k, size_t nr, size_t kr) { +size_t 
kai_get_rhs_packed_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon( + size_t n_idx, size_t k, size_t nr, size_t kr) { KAI_ASSUME(n_idx % nr == 0); return n_idx * (sizeof(uint32_t) + kai_roundup(k, kr) * sizeof(uint16_t)); } -size_t kai_get_rhs_packed_size_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t n, size_t k, size_t nr, size_t kr) { - return kai_get_rhs_packed_offset_rhs_quant_pack_bf16pbiasf32_f32_neon(kai_roundup(n, nr), k, nr, kr); +size_t kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n, size_t k, size_t nr, size_t kr) { + return kai_get_rhs_packed_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(kai_roundup(n, nr), k, nr, kr); } -void kai_run_rhs_quant_pack_bf16pbiasf32_f32_neon( - size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, - const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { - KAI_ASSUME(num_groups == 1); +void kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon( + size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const float* rhs, const float* bias, + void* rhs_packed) { KAI_ASSUME(sr == 1); KAI_ASSUME(rhs != NULL); - KAI_ASSUME(scale == NULL); KAI_ASSUME(rhs_packed != NULL); - KAI_ASSUME(extra_bytes == 0); - KAI_ASSUME(params == NULL); KAI_ASSUME(nr <= MAX_NR); size_t height = k; const size_t width = n; - const void* in = rhs; + const void* in = (void*)rhs; void* out = rhs_packed; const size_t in_stride = rhs_stride; - float* pad_row = (float*)alloca(width * sizeof(float)); - - if (height % 4) { - memset(pad_row, 0, width * sizeof(float)); - } + const float* pad_row = rhs; // Fill zeros if bias is nullptr size_t bias_step = nr * sizeof(float); uint8_t zero_bias[MAX_NR * sizeof(float)]; if (bias == NULL) { - memset(zero_bias, 0, bias_step); + memset(zero_bias, 0, MAX_NR * sizeof(float)); bias_step = 0; } - const void* bias_ptr = bias == NULL ? zero_bias : bias; + const void* bias_ptr = bias == NULL ? 
(void*)zero_bias : (void*)bias; - size_t out_stride = nr * kai_roundup(height, kr) * sizeof(uint16_t) + nr * sizeof(uint32_t); + const size_t out_stride = nr * kai_roundup(height, kr) * sizeof(uint16_t) + nr * sizeof(uint32_t); __asm__ __volatile__( "mov x22, %x[width]\n" diff --git a/kai/ukernels/matmul/pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.h b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h similarity index 61% rename from kai/ukernels/matmul/pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.h rename to kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h index e161caf9..7e4e200d 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.h +++ b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h @@ -17,21 +17,21 @@ extern "C" { /// The starting row index must be divisible by `n_step`. /// /// @return The n step value. -size_t kai_get_n_step_rhs_quant_pack_bf16pbiasf32_f32_neon(void); +size_t kai_get_n_step_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(void); /// Gets the offset in bytes to the data element in the RHS matrix buffer. /// /// @param[in] n_idx Column index. /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_offset_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t n_idx); +size_t kai_get_rhs_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n_idx); /// Gets the offset in bytes to the data element in the bias buffer. /// /// @param[in] n_idx Column index. /// /// @return The offset in bytes to the data element. -size_t kai_get_bias_offset_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t n_idx); +size_t kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n_idx); /// Gets the offset in bytes to the data element in the packed RHS buffer. /// @@ -41,7 +41,7 @@ size_t kai_get_bias_offset_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t n_idx); /// @param[in] kr Block size in K dimension. /// /// @return The offset in bytes to the data element. 
-size_t kai_get_rhs_packed_offset_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t n_idx, size_t k, size_t nr, size_t kr); +size_t kai_get_rhs_packed_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n_idx, size_t k, size_t nr, size_t kr); /// Gets the size in bytes of the packed RHS buffer. /// @@ -51,18 +51,17 @@ size_t kai_get_rhs_packed_offset_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t n_i /// @param[in] kr Block size in K dimension. /// /// @return The size in bytes of the packed RHS buffer. -size_t kai_get_rhs_packed_size_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t n, size_t k, size_t nr, size_t kr); +size_t kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n, size_t k, size_t nr, size_t kr); /// Runs the RHS packing function for matrix multiplication. /// /// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset /// calculated using the following functions: /// -/// * RHS: @ref kai_get_rhs_offset_rhs_quant_pack_bf16pbiasf32_f32_neon. -/// * Bias: @ref kai_get_bias_offset_rhs_quant_pack_bf16pbiasf32_f32_neon. -/// * Output: @ref kai_get_rhs_packed_offset_rhs_quant_pack_bf16pbiasf32_f32_neon. +/// * RHS: @ref kai_get_rhs_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon. +/// * Bias: @ref kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon. +/// * Output: @ref kai_get_rhs_packed_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon. /// -/// @param[in] num_groups Number of groups. It must be 1. /// @param[in] n Number of columns of the output matrix. /// @param[in] k Common dimension between the LHS and RHS matrix. /// @param[in] nr Block size in N dimension. @@ -71,13 +70,10 @@ size_t kai_get_rhs_packed_size_rhs_quant_pack_bf16pbiasf32_f32_neon(size_t n, si /// @param[in] rhs_stride Row stride in bytes of the RHS matrix. /// @param[in] rhs RHS matrix data buffer. /// @param[in] bias Bias matrix data buffer. -/// @param[in] scale Scale data buffer. It must be NULL. 
/// @param[out] rhs_packed Packed RHS matrix. -/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. -/// @param[in] params Extra packing parameters. It must be NULL. -void kai_run_rhs_quant_pack_bf16pbiasf32_f32_neon( - size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, - const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params); +void kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon( + size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const float* rhs, const float* bias, + void* rhs_packed); #ifdef __cplusplus } // extern "C" diff --git a/test/common/matmul_test_common.hpp b/test/common/matmul_test_common.hpp index d7151abf..6722ff84 100644 --- a/test/common/matmul_test_common.hpp +++ b/test/common/matmul_test_common.hpp @@ -260,10 +260,10 @@ struct MatMulMethod { /// @param[in] clamp_min Lower bound of the output data. /// @param[in] clamp_max Upper bound of the output data. 
std::function fn_matmul_f32_bf16p_bf16p = nullptr; @@ -340,7 +340,9 @@ struct MatMulMethod { } else if (fn_matmul_f32_f32p_f32p) { fn_matmul_f32_f32p_f32p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); } else if (fn_matmul_f32_bf16p_bf16p) { - fn_matmul_f32_bf16p_bf16p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); + fn_matmul_f32_bf16p_bf16p( + m, n, k, reinterpret_cast(lhs), rhs, reinterpret_cast(dst), dst_stride, + sizeof(float), clamp_min, clamp_max); } else { KAI_ERROR("Main kernel is not available!"); } diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp index a68fc7aa..42b3ef1a 100644 --- a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -32,13 +32,28 @@ #include "test/reference/pack.hpp" // matmul_clamp_f32_bf16p_bf16p -#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h" -#include "kai/ukernels/matmul/pack/kai_rhs_quant_pack_bf16pbiasf32_f32_neon.h" +#include "kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h" namespace kai::test { /// List of supported matrix multiplication methods. 
namespace { + +/// Adapters for using packing and matmul functions with the unified interface of the test framework +inline void kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon_adapter( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { + KAI_UNUSED(num_groups); + KAI_UNUSED(scale); + KAI_UNUSED(extra_bytes); + KAI_UNUSED(params); + + kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon( + n, k, nr, kr, sr, rhs_stride, reinterpret_cast(rhs), reinterpret_cast(bias), + rhs_packed); +} + const std::array matmul_methods = { MatMulMethod{ .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla", @@ -60,34 +75,34 @@ const std::array matmul_methods = { .bias_format = DataFormat(DataType::FP32), .fn_is_supported = cpu_has_bf16, - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_bf16pbiasf32_f32_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_pack_rhs_n_step = 
kai_get_n_step_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon, .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p_f32_neon, - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_bf16pbiasf32_f32_neon, + .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, .fn_get_packed_rhs_size = nullptr, - .fn_get_packed_rhs_size_generic_block_size = kai_get_rhs_packed_size_rhs_quant_pack_bf16pbiasf32_f32_neon, + .fn_get_packed_rhs_size_generic_block_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, .fn_get_pack_rhs_packed_rhs_offset = nullptr, .fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_pack_rhs = kai_run_rhs_quant_pack_bf16pbiasf32_f32_neon, + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon_adapter, - .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_bf16pbiasf32_f32_neon, + .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - .fn_matmul_f32_bf16p_bf16p = 
kai_run_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, }, MatMulMethod{ .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_opt_bias", @@ -109,36 +124,36 @@ const std::array matmul_methods = { .bias_format = DataFormat(DataType::UNKNOWN), .fn_is_supported = cpu_has_bf16, - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_bf16pbiasf32_f32_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon, .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, .fn_pack_lhs 
= kai_run_lhs_quant_pack_bf16p_f32_neon, - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_bf16pbiasf32_f32_neon, + .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, .fn_get_packed_rhs_size = nullptr, - .fn_get_packed_rhs_size_generic_block_size = kai_get_rhs_packed_size_rhs_quant_pack_bf16pbiasf32_f32_neon, + .fn_get_packed_rhs_size_generic_block_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, .fn_get_pack_rhs_packed_rhs_offset = nullptr, .fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_pack_rhs = kai_run_rhs_quant_pack_bf16pbiasf32_f32_neon, + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon_adapter, - .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_bf16pbiasf32_f32_neon, + .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, }}; -} +} // namespace /// Matrix multiplication test fixture. 
class MatMulTestBf16 : public testing::TestWithParam { @@ -275,8 +290,6 @@ TEST_P(MatMulTestBf16, Output) { GTEST_SKIP(); } - // ASSERT_FALSE(method.lhs_transposed()); - const size_t lhs_w = info.k; const size_t rhs_w = rect.width(); const size_t bias_w = info.n; -- GitLab From 6cb4cc6f342d29e28923aeea74a2ccb737cddbfc Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Thu, 17 Oct 2024 15:09:08 +0100 Subject: [PATCH 08/10] Revert back interface changes to the matmul and packing functions Signed-off-by: Gunes Bayir --- .../matmul_clamp_f32_bf16p_bf16p.cpp | 12 +++++++----- ..._bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c | 8 ++++---- ..._bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h | 8 ++++---- ...i_matmul_clamp_f32_bf16p_bf16p_interface.h | 2 +- ...rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c | 8 ++++++-- ...rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h | 8 ++++++-- test/common/matmul_test_common.hpp | 8 ++++---- .../matmul_clamp_f32_bf16p_bf16p_test.cpp | 19 ++----------------- 8 files changed, 34 insertions(+), 39 deletions(-) diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp index d6463410..0f524a54 100644 --- a/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp +++ b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp @@ -239,11 +239,13 @@ int main() { // Packing only needs to be performed once if the contents of the bias and RHS matrices are expected to be constant. 
kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon( - N, K, nr, kr, sr, // Packing arguments - rhs_stride, // RHS stride - rhs, // RHS - bias, // Bias - rhs_packed); // RHS packed + 1, N, K, nr, kr, sr, // Packing arguments + rhs_stride, // RHS stride + rhs, // RHS + bias, // Bias + NULL, // Scale + rhs_packed, // RHS packed + 0, NULL); // The RHS and Bias buffers can be freed after packing, however we reuse them for the reference test below diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c index 40495f9d..13c2ea67 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c @@ -70,10 +70,10 @@ size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla } void kai_run_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla( - size_t m, size_t n, size_t k, // - const uint16_t* lhs_packed, // - const void* rhs_packed, // - float* dst, size_t dst_stride_row, size_t dst_stride_col, // + size_t m, size_t n, size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, size_t dst_stride_row, size_t dst_stride_col, // float clamp_min, float clamp_max) { KAI_ASSERT(dst_stride_col == sizeof(float)); diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h index 03c8bb52..812a289f 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h +++ 
b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h @@ -115,10 +115,10 @@ size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla /// @param[in] clamp_min Minimum value to clamp the final result. /// @param[in] clamp_max Maximum value to clamp the final result. void kai_run_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla( - size_t m, size_t n, size_t k, // - const uint16_t* lhs_packed, // - const void* rhs_packed, // - float* dst, size_t dst_stride_row, size_t dst_stride_col, // + size_t m, size_t n, size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, size_t dst_stride_row, size_t dst_stride_col, // float clamp_min, float clamp_max); #ifdef __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h index 37d5fcc9..62f89279 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h @@ -32,7 +32,7 @@ typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_size_func_t)(size_t m, /// Micro-kernel core function ("run" method) typedef void (*kai_matmul_clamp_f32_bf16p_bf16p_run_matmul_func_t)( - size_t m, size_t n, size_t k, const uint16_t* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max); /// Micro-kernel interface diff --git a/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c index 6f778e2a..bf9eb107 100644 --- 
a/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c @@ -41,11 +41,15 @@ size_t kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n } void kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon( - size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const float* rhs, const float* bias, - void* rhs_packed) { + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { + KAI_ASSUME(num_groups == 1); KAI_ASSUME(sr == 1); KAI_ASSUME(rhs != NULL); + KAI_ASSUME(scale == NULL); KAI_ASSUME(rhs_packed != NULL); + KAI_ASSUME(extra_bytes == 0); + KAI_ASSUME(params == NULL); KAI_ASSUME(nr <= MAX_NR); size_t height = k; diff --git a/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h index 7e4e200d..f786c7a7 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h +++ b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h @@ -62,6 +62,7 @@ size_t kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n /// * Bias: @ref kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon. /// * Output: @ref kai_get_rhs_packed_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon. /// +/// @param[in] num_groups Number of groups. It must be 1. /// @param[in] n Number of columns of the output matrix. /// @param[in] k Common dimension between the LHS and RHS matrix. /// @param[in] nr Block size in N dimension. @@ -70,10 +71,13 @@ size_t kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n /// @param[in] rhs_stride Row stride in bytes of the RHS matrix. /// @param[in] rhs RHS matrix data buffer. /// @param[in] bias Bias matrix data buffer. 
+/// @param[in] scale Scale data buffer. It must be NULL. /// @param[out] rhs_packed Packed RHS matrix. +/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. +/// @param[in] params Extra packing parameters. It must be NULL. void kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon( - size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const float* rhs, const float* bias, - void* rhs_packed); + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params); #ifdef __cplusplus } // extern "C" diff --git a/test/common/matmul_test_common.hpp b/test/common/matmul_test_common.hpp index 6722ff84..21f6e244 100644 --- a/test/common/matmul_test_common.hpp +++ b/test/common/matmul_test_common.hpp @@ -260,10 +260,10 @@ struct MatMulMethod { /// @param[in] clamp_min Lower bound of the output data. /// @param[in] clamp_max Upper bound of the output data. std::function fn_matmul_f32_bf16p_bf16p = nullptr; diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp index 42b3ef1a..23690cce 100644 --- a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -39,21 +39,6 @@ namespace kai::test { /// List of supported matrix multiplication methods. 
namespace { - -/// Adapters for using packing and matmul functions with the unified interface of the test framework -inline void kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon_adapter( - size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, - const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { - KAI_UNUSED(num_groups); - KAI_UNUSED(scale); - KAI_UNUSED(extra_bytes); - KAI_UNUSED(params); - - kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon( - n, k, nr, kr, sr, rhs_stride, reinterpret_cast(rhs), reinterpret_cast(bias), - rhs_packed); -} - const std::array matmul_methods = { MatMulMethod{ .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla", @@ -95,7 +80,7 @@ const std::array matmul_methods = { .fn_get_pack_rhs_packed_rhs_offset = nullptr, .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon_adapter, + .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, @@ -144,7 +129,7 @@ const std::array matmul_methods = { .fn_get_pack_rhs_packed_rhs_offset = nullptr, .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon_adapter, + .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, -- GitLab From d852653a845418167dc9228f21cdc39daa0bf806 Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Thu, 17 Oct 2024 15:22:22 +0100 Subject: [PATCH 09/10] Bazel build issue fix Signed-off-by: Gunes Bayir --- kai/ukernels/matmul/BUILD.bazel | 2 -- 1 file changed, 2 deletions(-) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel 
index 5a0d0121..ed3df38f 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -365,8 +365,6 @@ kai_c_library( ":rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon", ":rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", ":rhs_pack_kxn_f32pb_f32_f32_16vlx1_sme", - ":rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon", - "rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon", ":rhs_pack_kxn_f32pbiasf32_f32_f32_neon", ":rhs_pack_kxn_qsi4c32p_qsu4c32s1s0", ":rhs_pack_kxn_qsi4cxp_qs4cxs1s0", -- GitLab From cf2f3e3842cae4489fef888a49ccf4f330f33a23 Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Sun, 20 Oct 2024 15:03:33 +0100 Subject: [PATCH 10/10] Change kernel name to adapt to implicit naming rules Signed-off-by: Gunes Bayir --- CMakeLists.txt | 2 +- .../CMakeLists.txt | 4 +- .../matmul_clamp_f32_bf16p_bf16p.cpp | 32 +++++++------ kai/ukernels/matmul/BUILD.bazel | 4 +- ...p_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c} | 24 +++++----- ...p_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h} | 29 ++++++----- .../matmul_clamp_f32_bf16p_bf16p_test.cpp | 48 +++++++++---------- 7 files changed, 71 insertions(+), 72 deletions(-) rename kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/{kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c => kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c} (95%) rename kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/{kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h => kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h} (78%) diff --git a/CMakeLists.txt b/CMakeLists.txt index fda7a92c..4d846759 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -91,7 +91,7 @@ set(KLEIDIAI_FILES_NEON_FP16 set(KLEIDIAI_FILES_NEON_BF16 kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.c kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c - kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c + 
kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c ) set(KLEIDIAI_FILES_NEON diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt index e1fafe23..4b13a183 100644 --- a/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt +++ b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt @@ -4,7 +4,7 @@ # SPDX-License-Identifier: Apache-2.0 # -cmake_minimum_required(VERSION 3.18) +cmake_minimum_required(VERSION 3.16) project(KleidiAI) @@ -22,7 +22,7 @@ include_directories( # Files requires to build the executable add_executable(matmul_clamp_f32_bf16p_bf16p matmul_clamp_f32_bf16p_bf16p.cpp - ${MATMUL_PATH}/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c + ${MATMUL_PATH}/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c ${MATMUL_PACK_PATH}/kai_lhs_quant_pack_bf16p_f32_neon.c ${MATMUL_PACK_PATH}/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c ) diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp index 0f524a54..f11900c3 100644 --- a/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp +++ b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp @@ -29,7 +29,7 @@ // Include micro-kernel variants #include "kai/kai_common.h" #include "kai_lhs_quant_pack_bf16p_f32_neon.h" -#include "kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h" +#include "kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h" #include "kai_matmul_clamp_f32_bf16p_bf16p_interface.h" #include "kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h" @@ -41,17 +41,17 @@ inline static float bf16_to_float(const uint16_t* v) { namespace { /// Micro-kernel interface constexpr kai_matmul_clamp_f32_bf16p_bf16p_ukernel ukernel{ - kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - 
kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - kai_run_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla}; + kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla}; /// @brief Truncate the 32-bit floating point number's least significant 16 mantissa bits /// @param x floating-point number @@ -279,6 +279,8 @@ int main() { const auto time_matmul = std::chrono::duration_cast(timer_matmul_end - timer_matmul_start); + int ret = 0; + #ifdef KAI_DEBUG int num_lhs_rows = (M + mr - 1) / mr; int num_lhs_cols = mr * kai_roundup(K, kr); @@ -292,13 +294,13 @@ int main() { const bool is_valid = is_output_correct(M, N, rel_tolerance, dst_ref, dst); 
std::cout << "TEST[matmul_clamp_f32_bf16p_bf16p]\n"; - std::cout << "- ukernel: matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla\n"; + std::cout << "- ukernel: matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla\n"; if (is_valid) { std::cout << "- Status: PASSED\n"; std::cout << "- Performance: " << time_matmul.count() << "ns\n"; } else { std::cout << "- Status: FAILED\n"; - return 1; + ret = 1; } //----------- END MICRO-KERNELS TESTS @@ -313,7 +315,7 @@ int main() { delete[] dst; delete[] dst_ref; - return 0; + return ret; } #endif // Architectural features check. diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index ed3df38f..66a3c386 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -41,8 +41,8 @@ kai_c_library( kai_c_library( name = "clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla", - srcs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c"], - hdrs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h"], + srcs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c"], + hdrs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h"], cpu_uarch = kai_cpu_bf16(), deps = [ ":clamp_f32_bf16p_bf16p_interface", diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c similarity index 95% rename from kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c rename to kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c index 13c2ea67..929e3753 100644 --- 
a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c @@ -8,7 +8,7 @@ #error This file must be compiled for AArch64, FEAT_BF16. #else // Architectural features check. -#include "kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h" +#include "kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h" #include #include @@ -21,43 +21,43 @@ static const size_t kai_nr = 12; static const size_t kai_kr = 4; static const size_t kai_sr = 1; -size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void) { +size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { return kai_mr; } -size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void) { +size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { return kai_nr; } -size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void) { +size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { return kai_mr; } -size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void) { +size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { return kai_nr; } -size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void) { +size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void) { +size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(size_t m_idx, size_t k) { +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m_idx, size_t k) { KAI_ASSUME(m_idx % kai_mr == 0); return m_idx * kai_roundup(k, 
kai_kr) * sizeof(uint16_t); } -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t n_idx, size_t k) { KAI_ASSUME(n_idx % kai_nr == 0); return n_idx * (sizeof(float) + kai_roundup(k, kai_kr) * sizeof(uint16_t)); } -size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla( +size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla( size_t m_idx, size_t n_idx, size_t stride) { KAI_ASSUME(m_idx % kai_mr == 0); KAI_ASSUME(n_idx % kai_nr == 0); @@ -65,11 +65,11 @@ size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mm return m_idx * stride + n_idx * sizeof(float); } -size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m, size_t n) { return m * n * sizeof(float); } -void kai_run_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla( +void kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla( size_t m, size_t n, size_t k, // const void* lhs_packed, // const void* rhs_packed, // diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h similarity index 78% rename from kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h rename to kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h index 812a289f..e870fb2a 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h +++ 
b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h @@ -24,42 +24,42 @@ extern "C" { /// The starting row index must be divisible by `m_step`. /// /// @return The m step value. -size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void); +size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); /// Gets n step value. /// /// The starting column index must be divisible by `n_step`. /// /// @return The n step value. -size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void); +size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); /// Gets mr value. /// /// This is the packing parameter which must be used to pack the LHS matrix. /// /// @return The mr value. -size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void); +size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); /// Gets nr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The nr value. -size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void); +size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); /// Gets kr value. /// /// This is the packing parameter which must be used to pack the LHS & RHS matrices. /// /// @return The kr value. -size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void); +size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); /// Gets sr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The sr value. -size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void); +size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); /// Gets the offset in bytes to the data element in the packed LHS matrix buffer. 
/// @@ -67,7 +67,7 @@ size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(void) /// @param[in] k Number of columns. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m_idx, size_t k); /// Gets the offset in bytes to the data element in the packed RHS matrix buffer. /// @@ -75,7 +75,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_ /// @param[in] k Number of columns. /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t n_idx, size_t k); /// Gets the offset in bytes to the data element in the destination matrix buffer. /// @@ -84,8 +84,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_ /// @param[in] stride Row stride in bytes. /// /// @return The offset in bytes to the data element. -size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla( - size_t m_idx, size_t n_idx, size_t stride); +size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m_idx, size_t n_idx, size_t stride); /// Gets the size in bytes of the destination matrix buffer. /// @@ -93,16 +92,16 @@ size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mm /// @param[in] n Number of columns. /// /// @return The size in bytes of the destination matrix buffer. 
-size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla(size_t m, size_t n); +size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m, size_t n); /// Runs the matrix multiplication microkernel followed by a clamp operation. /// /// The pointer of each buffers (packed LHS, packed RHS and output) needs to be added with offset /// calculated using the following functions: /// -/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla. -/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla. -/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla. +/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla. /// /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. @@ -114,7 +113,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) /// @param[in] clamp_min Minimum value to clamp the final result. /// @param[in] clamp_max Maximum value to clamp the final result. 
-void kai_run_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla( +void kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla( size_t m, size_t n, size_t k, // const void* lhs_packed, // const void* rhs_packed, // diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp index 23690cce..730ff5ae 100644 --- a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -32,7 +32,7 @@ #include "test/reference/pack.hpp" // matmul_clamp_f32_bf16p_bf16p -#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h" #include "kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h" namespace kai::test { @@ -60,34 +60,33 @@ const std::array matmul_methods = { .bias_format = DataFormat(DataType::FP32), .fn_is_supported = cpu_has_bf16, - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, .fn_get_pack_rhs_n_step = 
kai_get_n_step_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon, .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p_f32_neon, .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, .fn_get_packed_rhs_size = nullptr, .fn_get_packed_rhs_size_generic_block_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, }, MatMulMethod{ .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_opt_bias", @@ -109,34 +108,33 @@ 
const std::array matmul_methods = { .bias_format = DataFormat(DataType::UNKNOWN), .fn_is_supported = cpu_has_bf16, - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon, .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p_f32_neon, .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, .fn_get_packed_rhs_size = nullptr, .fn_get_packed_rhs_size_generic_block_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = - 
kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p_bf16p12x4biasf32_8x12x4_neon_mmla, + .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, }}; } // namespace -- GitLab