diff --git a/CMakeLists.txt b/CMakeLists.txt index 748a9ae3c8fb3bb3ed72b689d4e85d3d2dab4f76..e3c621c58c9abf02e060a2814efe2215aa9af33c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -89,6 +89,14 @@ set(KLEIDIAI_FILES_NEON_FP16 kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c ) +set(KLEIDIAI_FILES_NEON_BF16 + kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon.c + kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c + kai/ukernels/matmul/matmul_clamp_f32_f32_bf16p/kai_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla.c +) + set(KLEIDIAI_FILES_NEON kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.c @@ -134,6 +142,7 @@ target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SCALAR}) if((CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") AND NOT MSVC) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_FP16}) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_BF16}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_DOTPROD}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_I8MM}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME}) @@ -142,6 +151,7 @@ if((CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS -march=armv8-a${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON} PROPERTIES COMPILE_OPTIONS -march=armv8-a${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_FP16} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+fp16${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_NEON_BF16} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+bf16${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+dotprod${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+i8mm${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_SME} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+sve+sve2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) @@ -167,6 +177,7 @@ if(KLEIDIAI_BUILD_TESTS) test/common/printer.cpp test/common/int4.cpp test/common/compare.cpp + test/common/matmul_test_common.cpp test/common/matrix_portion.cpp test/common/rect.cpp test/common/round.cpp @@ -201,6 +212,8 @@ if(KLEIDIAI_BUILD_TESTS) test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp + test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp + test/tests/matmul_clamp_f32_f32_bf16p_test.cpp ) target_link_libraries(kleidiai_test diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..62007cf7531b05642ca0b9a23e87299b2082e79f --- /dev/null +++ b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt @@ -0,0 +1,34 @@ +# +# SPDX-FileCopyrightText: 
Copyright 2024 Arm Limited and/or its affiliates
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+
+cmake_minimum_required(VERSION 3.16)
+
+set(CMAKE_CXX_STANDARD 17)
+set(KLEIDIAI_PATH ../../)
+set(MATMUL_PACK_PATH ${KLEIDIAI_PATH}/kai/ukernels/matmul/pack/)
+set(MATMUL_PATH ${KLEIDIAI_PATH}/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/)
+
+# KleidiAI include directories
+include_directories(
+    ${KLEIDIAI_PATH}
+    ${MATMUL_PACK_PATH}
+    ${MATMUL_PATH})
+
+# Files required to build the executable
+add_executable(matmul_clamp_f32_bf16p_bf16p
+    matmul_clamp_f32_bf16p_bf16p.cpp
+    ${MATMUL_PATH}/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c
+    ${MATMUL_PACK_PATH}/kai_lhs_pack_f32p8x4_bf16_neon.c
+    ${MATMUL_PACK_PATH}/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c
+)
+
+target_compile_options(matmul_clamp_f32_bf16p_bf16p
+    PRIVATE -march=armv8.2-a+bf16
+)
+
+target_compile_definitions(matmul_clamp_f32_bf16p_bf16p
+    PRIVATE $<$<CONFIG:Debug>:KAI_DEBUG>
+)
diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..49f5a1f9f46d9d38dd7bf2cf3a78c3d196ba7bfe
--- /dev/null
+++ b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp
@@ -0,0 +1,331 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+// Example usage for matrix multiplication of two packed bfloat16 (BF16) matrices and the accumulation of
+// the result into a single precision floating-point (FP32) destination matrix.
+//
+// The activations and the weights, stored in the LHS and RHS matrices respectively, are both non-transposed matrices.
+// The matrix multiplication computation is performed using the bfloat16 matrix multiply-accumulate (BFMMLA)
+// vector instructions present in the FEAT_BF16 Arm® architecture feature.
+//
+#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
+#error This file must be compiled for AArch64, FEAT_BF16.
+#else
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cfloat>
+#include <chrono>
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <iomanip>
+#include <iostream>
+
+// Include micro-kernel variants
+#include "kai_lhs_pack_f32p8x4_bf16_neon.h"
+#include "kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h"
+#include "kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h"
+#include "matmul_clamp_f32_bf16p_bf16p_interface.h"
+
+inline float bf16_to_float(uint16_t v) {
+    const uint32_t lv = (v << 16);
+    float fp;
+    memcpy(&fp, &lv, sizeof(lv));
+    return fp;
+}
+
+inline float bf16_to_float(const bfloat16_t* v) {
+    const uint16_t uint_rep = *reinterpret_cast<const uint16_t*>(v);
+    return bf16_to_float(uint_rep);
+}
+
+namespace {
+/// Micro-kernel interface
+constexpr kai_matmul_clamp_f32_bf16p_bf16p_ukernel ukernel{
+    kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla,
+    kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla,
+    kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla,
+    kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla,
+    kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla,
+    kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla,
+    kai_get_lhs_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla,
+    kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla,
+    kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla,
+    kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla,
+    kai_run_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla};
+
+/// Truncates an FP32 value to BF16 precision by zeroing the low 16 mantissa bits.
+float truncate(float x) {
+    uint32_t uval = (*reinterpret_cast<uint32_t*>(&x) & 0xffff0000);
+    return *reinterpret_cast<float*>(&uval);
+}
+
+/// Reference implementation of matrix multiplication
+void run_matmul_ref(
+    size_t m, size_t n, size_t k, const float* lhs, const float* rhs, const float* bias, float* dst, float scalar_min,
+    float scalar_max) {
+    for (size_t row_idx = 0; row_idx < m; ++row_idx) {
+        for (size_t col_idx = 0; col_idx < n; ++col_idx) {
+            float acc = bias[col_idx];
+
+            for (size_t k_idx = 0; k_idx < k; ++k_idx) {
+                float lhs_val = truncate(lhs[row_idx * k + k_idx]);
+                float rhs_val = truncate(rhs[col_idx + n * k_idx]);
+
+                acc += lhs_val * rhs_val;
+            }
+            acc = std::max(acc, scalar_min);
+            acc = std::min(acc, scalar_max);
+
+            dst[row_idx * n + col_idx] = acc;
+        }
+    }
+}
+
+/// Fills the matrix with incremental values
+void fill_matrix(size_t num_rows, size_t num_cols, float* dst, const float weight) {
+    for (size_t i = 0; i < num_rows * num_cols; i++) {
+        dst[i] = float((i + 1) * weight);
+    }
+}
+
+/// Fills the matrix with an identity pattern (unused by this example; kept for experimentation)
+void fill_identity(size_t num_rows, size_t num_cols, float* dst, const float weight) {
+    (void)weight;  // unused
+    for (size_t i = 0; i < num_rows * num_cols; i++) {
+        int col = i % num_cols;
+        int row = i / num_cols;
+
+        dst[i] = (col == row ? 1.f : 0.f);
+    }
+}
+
+/// Print the matrix
+void print_matrix(size_t num_rows, size_t num_cols, const char* name, const float* src) {
+    std::cout << name << " = [\n";
+    for (size_t y = 0; y < num_rows; ++y) {
+        std::cout << "  [";
+        for (size_t x = 0; x < num_cols; ++x) {
+            std::cout << std::setprecision(2) << std::fixed << src[y * num_cols + x] << ", ";
+        }
+        std::cout << ("],\n");
+    }
+    std::cout << ("]\n\n");
+}
+
+void print_matrix(size_t num_rows, size_t num_cols, const char* name, const bfloat16_t* src) {
+    std::cout << name << " = [\n";
+    for (size_t y = 0; y < num_rows; ++y) {
+        std::cout << "  [";
+        for (size_t x = 0; x < num_cols; ++x) {
+            std::cout << std::setprecision(2) << std::fixed << bf16_to_float(&src[y * num_cols + x]) << ", ";
+        }
+        std::cout << ("],\n");
+    }
+    std::cout << ("]\n\n");
+}
+
+/// Prints a packed matrix whose rows hold nr FP32 values followed by BF16 values.
+void print_mixed_prec_matrix(
+    size_t num_rows, size_t num_cols, const char* name, const uint8_t* src, int nr, int stride) {
+    std::cout << name << " = [\n";
+    for (size_t y = 0; y < num_rows; ++y) {
+        std::cout << "  [";
+        const uint8_t* src_row = src + stride * y;
+        for (size_t x = 0; x < num_cols; ++x) {
+            if (x >= static_cast<size_t>(nr)) {
+                // print bfloat
+                const bfloat16_t* src_elm =
+                    reinterpret_cast<const bfloat16_t*>(src_row + nr * sizeof(float) + (x - nr) * sizeof(bfloat16_t));
+                std::cout << std::setprecision(2) << std::fixed << bf16_to_float(src_elm) << ", ";
+            } else {
+                // print float
+                const float* src_elm = reinterpret_cast<const float*>(src_row + x * sizeof(float));
+                std::cout << std::setprecision(2) << std::fixed << *src_elm << ", ";
+            }
+        }
+        std::cout << ("],\n");
+    }
+    std::cout << ("]\n\n");
+}
+
+void print_bf_matrix(size_t num_rows, size_t num_cols, const char* name, const float* src) {
+    std::cout << name << " = [\n";
+    for (size_t y = 0; y < num_rows; ++y) {
+        std::cout << "  [";
+        for (size_t x = 0; x < num_cols; ++x) {
+            std::cout << std::setprecision(2) << std::fixed << truncate(src[y * num_cols + x]) << ", ";
+        }
+        std::cout << ("],\n");
+    }
+    std::cout << ("]\n\n");
+}
+
+/// Verify the micro-kernel output matches the reference implementation
+bool is_output_correct(
+    size_t num_rows, size_t num_cols, const float rel_tolerance, const float* ref, const float* act) {
+    bool is_valid = true;
+
+    for (size_t i = 0; i < num_rows * num_cols; ++i) {
+        if (std::fabs(ref[i] - act[i]) / (std::fabs(act[i]) + 1e-10) > rel_tolerance) {
+            const size_t x = i % num_cols;
+            const size_t y = i / num_cols;
+
+            std::cout << std::setprecision(5) << std::fixed << "ERROR![" << y << "][" << x << "]: ref=" << ref[i]
+                      << " vs. act=" << act[i] << "\n";
+
+            is_valid = false;
+        }
+    }
+    return is_valid;
+}
+}  // namespace
+
+int main() {
+    // Parameters of the matrix multiplication. Change these values to see how the micro-kernels operate on different
+    // sized matrices.
+    const size_t M = 5;  // Rows of LHS and DST matrices
+    const size_t N = 8;  // Columns of RHS and DST matrices, and length of the Bias vector.
+ const size_t K = 7; // Columns of LHS, rows of RHS matrices + + const size_t lhs_size = M * K; + const size_t rhs_size = N * K; + const size_t bias_size = N; + const size_t dst_size = M * N; + + // Allocate the memory + float* lhs = new float[lhs_size]; + float* rhs = new float[rhs_size]; + float* bias = new float[bias_size]; + + fill_matrix(M, K, lhs, 0.1); + fill_matrix(K, N, rhs, 0.1); + fill_matrix(1, N, bias, 1); + +#ifdef KAI_DEBUG + // std::cout << "Floats: " << std::endl; + print_matrix(M, K, "lhs", lhs); + print_matrix(K, N, "rhs", rhs); + print_matrix(1, N, "bias", bias); + + // Print bf16 converted values + print_bf_matrix(M, K, "lhs_bf", lhs); + print_bf_matrix(K, N, "rhs_bf", rhs); +#endif // KAI_DEBUG + + //----------- REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + float* dst_ref = new float[dst_size]; + + run_matmul_ref( + M, N, K, // Dimensions + lhs, // LHS buffer + rhs, // RHS buffer + bias, // Bias buffer + dst_ref, // DST + FLT_MIN, FLT_MAX // Min and max for the clamp operation + ); + //----------- END REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + + //----------- MICRO-KERNELS TESTS + //------------------------------------ + //------------------------------------ + const size_t mr = ukernel.get_mr(); + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + + // In a single row, we pack nr bias values followed by K rows of nr RHS values + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(N, K); + uint8_t* rhs_packed = new uint8_t[rhs_packed_size]; + + const size_t lhs_stride = K * sizeof(float); + const size_t rhs_stride = N * sizeof(float); + const size_t dst_stride_row = N * sizeof(float); + const size_t dst_stride_col = sizeof(float); + + const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_pack_f32p8x4_bf16_neon(M, K, mr, kr, sr); + bfloat16_t* lhs_packed = new bfloat16_t[lhs_packed_size]; + + // Packing only needs to be performed once if the contents of the bias and RHS matrices are expected to be constant. 
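+    // Packed RHS layout, inferred from the nr/kr values above and the KAI_DEBUG stride
+    // computation below: the N columns are grouped into blocks of nr; each block is stored as
+    // nr f32 bias values followed by kai_roundup(K, kr) rows of nr bf16 values, so one block
+    // occupies nr * sizeof(float) + kai_roundup(K, kr) * nr * sizeof(bfloat16_t) bytes.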
+    kai_run_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(
+        1, N, K, nr, kr, sr,  // Packing arguments
+        rhs_stride,           // RHS stride
+        rhs,                  // RHS
+        bias,                 // Bias
+        NULL,                 // Scale
+        rhs_packed,           // RHS packed
+        0, NULL);
+
+    // The RHS and Bias buffers could be freed here after packing; this example keeps them and frees everything at the end.
+
+#ifdef KAI_DEBUG
+    const size_t rhs_packed_cols = nr + kai_roundup(K, kr) * nr;
+
+    // Each packed row has nr floats and then K*nr bfloats
+    int rhs_packed_stride = nr * sizeof(float) + kai_roundup(K, kr) * nr * sizeof(bfloat16_t);
+    const size_t rhs_packed_rows = rhs_packed_size / rhs_packed_stride;
+
+    print_mixed_prec_matrix(rhs_packed_rows, rhs_packed_cols, "rhs_packed", rhs_packed, nr, rhs_packed_stride);
+#endif  // KAI_DEBUG
+
+    float* dst = new float[dst_size];
+
+    kai_run_lhs_pack_f32p8x4_bf16_neon(M, K, mr, kr, sr, 0 /* m_idx_start */, lhs, lhs_stride, lhs_packed);
+
+    const auto timer_matmul_start = std::chrono::high_resolution_clock::now();
+
+    ukernel.run_matmul(
+        M, N, K,          // Dimensions
+        lhs_packed,       // LHS packed
+        rhs_packed,       // RHS packed
+        dst,              // DST
+        dst_stride_row,   // DST stride (row)
+        dst_stride_col,   // DST stride (col)
+        FLT_MIN, FLT_MAX  // Min and max for the clamp operation
+    );
+
+    const auto timer_matmul_end = std::chrono::high_resolution_clock::now();
+    const auto time_matmul =
+        std::chrono::duration_cast<std::chrono::nanoseconds>(timer_matmul_end - timer_matmul_start);
+
+#ifdef KAI_DEBUG
+    int num_lhs_rows = (M + mr - 1) / mr;
+    int num_lhs_cols = mr * kai_roundup(K, kr);
+
+    print_matrix(num_lhs_rows, num_lhs_cols, "lhs_packed", lhs_packed);
+    print_matrix(M, N, "dst", dst);
+    print_matrix(M, N, "ref", dst_ref);
+#endif  // KAI_DEBUG
+
+    const bool is_valid = is_output_correct(M, N, 0.02 /* rel tol */, dst_ref, dst);
+
+    std::cout << "TEST[matmul_clamp_f32_bf16p_bf16p]\n";
+    std::cout << "- ukernel: matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla\n";
+    if (is_valid) {
+        std::cout << "- Status: PASSED\n";
+        std::cout << "- Performance: " << time_matmul.count() << "ns\n";
+    } else {
+        std::cout << "- Status: FAILED\n";
+        return 1;
+    }
+
+    //----------- END MICRO-KERNELS TESTS
+    //------------------------------------
+    //------------------------------------
+
+    delete[] lhs;
+    delete[] rhs;
+    delete[] bias;
+    delete[] lhs_packed;
+    delete[] rhs_packed;
+    delete[] dst;
+    delete[] dst_ref;
+
+    return 0;
+}
+
+#endif  // Architectural features check.
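A note on rounding in the example above: run_matmul_ref() truncates each f32 operand to bf16 precision (truncate() zeroes the low 16 mantissa bits), which matches the non-BF16 fallback of kai_cast_bf16_f32 in the kai/kai_common.h hunk later in this diff, whereas the BFCVT instruction used on FEAT_BF16 hardware rounds to nearest-even. That difference is one reason the example compares against the reference with a relative tolerance rather than exact equality. A stand-alone C++ sketch of the two conversions (not part of the patch; NaN handling omitted for brevity):

#include <cstdint>
#include <cstring>
#include <iostream>

// Truncating f32 -> bf16 conversion: drop the low 16 bits.
static uint16_t bf16_truncate(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
}

// Round-to-nearest-even f32 -> bf16 conversion, as BFCVT performs.
static uint16_t bf16_rne(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    bits += 0x7FFFu + ((bits >> 16) & 1u);  // add rounding bias, ties to even
    return static_cast<uint16_t>(bits >> 16);
}

int main() {
    const float x = 0.1f;  // IEEE-754 bits: 0x3DCCCCCD
    std::cout << std::hex << bf16_truncate(x) << "\n";  // 3dcc (~0.09961)
    std::cout << std::hex << bf16_rne(x) << "\n";       // 3dcd (~0.10010)
}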
diff --git a/examples/matmul_clamp_f32_f32_bf16p/CMakeLists.txt b/examples/matmul_clamp_f32_f32_bf16p/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7dffb66673fd0145ad0173ec41ed066055b72f63
--- /dev/null
+++ b/examples/matmul_clamp_f32_f32_bf16p/CMakeLists.txt
@@ -0,0 +1,33 @@
+#
+# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+
+cmake_minimum_required(VERSION 3.16)
+
+set(CMAKE_CXX_STANDARD 17)
+set(KLEIDIAI_PATH ../../)
+set(MATMUL_PACK_PATH ${KLEIDIAI_PATH}/kai/ukernels/matmul/pack/)
+set(MATMUL_PATH ${KLEIDIAI_PATH}/kai/ukernels/matmul/matmul_clamp_f32_f32_bf16p/)
+
+# KleidiAI include directories
+include_directories(
+    ${KLEIDIAI_PATH}
+    ${MATMUL_PACK_PATH}
+    ${MATMUL_PATH})
+
+# Files required to build the executable
+add_executable(matmul_clamp_f32_f32_bf16p
+    matmul_clamp_f32_f32_bf16p.cpp
+    ${MATMUL_PATH}/kai_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla.c
+    ${MATMUL_PACK_PATH}/kai_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon.c
+)
+
+target_compile_options(matmul_clamp_f32_f32_bf16p
+    PRIVATE -march=armv8.6-a+fp16+bf16
+)
+
+target_compile_definitions(matmul_clamp_f32_f32_bf16p
+    PRIVATE $<$<CONFIG:Debug>:KAI_DEBUG>
+)
diff --git a/examples/matmul_clamp_f32_f32_bf16p/matmul_clamp_f32_f32_bf16p.cpp b/examples/matmul_clamp_f32_f32_bf16p/matmul_clamp_f32_f32_bf16p.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d31c99dc76849f73520d0504ec5c158753a2a076
--- /dev/null
+++ b/examples/matmul_clamp_f32_f32_bf16p/matmul_clamp_f32_f32_bf16p.cpp
@@ -0,0 +1,346 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+// Example usage for matrix multiplication of two single precision floating-point (FP32) matrices, with the RHS
+// packed into bfloat16, and the accumulation of the result into an FP32 destination matrix.
+//
+// The activations and the weights, stored in the LHS and RHS matrices respectively, are both non-transposed matrices.
+// The matrix multiplication computation is performed using the bfloat16 matrix multiply-accumulate (BFMMLA) vector
+// instructions present in the FEAT_BF16 Arm® architecture feature.
+//
+#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_SCALAR_ARITHMETIC) || \
+    !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
+#error This file must be compiled for AArch64, FEAT_BF16.
+#else
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cfloat>
+#include <chrono>
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <iomanip>
+#include <iostream>
+
+// Include micro-kernel variants
+#include "kai/kai_common.h"
+#include "kai_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla.h"
+#include "kai_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon.h"
+#include "matmul_clamp_f32_f32_bf16p_interface.h"
+
+#define FLOAT16_MIN (FLT_MIN)
+#define FLOAT16_MAX (FLT_MAX)
+
+/** Convert bfloat16 to float
+ *
+ * @param[in] v Bfloat16 value to convert to float
+ *
+ * @return Converted value
+ */
+inline float bf16_to_float(const bfloat16_t* v) {
+    const uint32_t lv = ((*reinterpret_cast<const uint16_t*>(v)) << 16);
+    float fp;
+    memcpy(&fp, &lv, sizeof(lv));
+    return fp;
+}
+
+inline float bf16_to_float(uint16_t v) {
+    const uint32_t lv = (v << 16);
+    float fp;
+    memcpy(&fp, &lv, sizeof(lv));
+    return fp;
+}
+
+namespace {
+/// Micro-kernel interface
+
+// a64
+constexpr kai_matmul_clamp_f32_f32_bf16p_ukernel ukernel{
+    kai_get_m_step_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+    kai_get_n_step_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+    kai_get_mr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+    kai_get_nr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+    kai_get_kr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+    kai_get_sr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+    kai_get_lhs_offset_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+    kai_get_rhs_packed_offset_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+    kai_get_dst_offset_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+    kai_get_dst_size_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+    kai_run_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla};
+
+/// Truncates an FP32 value to BF16 precision by zeroing the low 16 mantissa bits.
+float truncate(float x) {
+    uint32_t uval = (*reinterpret_cast<uint32_t*>(&x) & 0xffff0000);
+    return *reinterpret_cast<float*>(&uval);
+}
+
+/// Reference implementation of matrix multiplication
+void run_matmul_ref(
+    size_t m, size_t n, size_t k, const float* lhs, const float* rhs, const float* bias, float* dst, float scalar_min,
+    float scalar_max) {
+    for (size_t row_idx = 0; row_idx < m; ++row_idx) {
+        for (size_t col_idx = 0; col_idx < n; ++col_idx) {
+            float acc = bias[col_idx];
+
+            for (size_t k_idx = 0; k_idx < k; ++k_idx) {
+                float lhs_val = truncate(lhs[row_idx * k + k_idx]);
+                float rhs_val = truncate(rhs[col_idx + n * k_idx]);
+
+                acc += lhs_val * rhs_val;
+            }
+            acc = std::max(acc, scalar_min);
+            acc = std::min(acc, scalar_max);
+
+            dst[row_idx * n + col_idx] = acc;
+        }
+    }
+}
+
+/// Fills the matrix with incremental values
+void fill_matrix(size_t num_rows, size_t num_cols, float* dst, const float weight) {
+    for (size_t i = 0; i < num_rows * num_cols; i++) {
+        dst[i] = float((i + 1) * weight);
+    }
+}
+
+/// Fills the matrix with an identity pattern (unused by this example; kept for experimentation)
+void fill_identity(size_t num_rows, size_t num_cols, float* dst, const float weight) {
+    (void)weight;  // unused
+    for (size_t i = 0; i < num_rows * num_cols; i++) {
+        int col = i % num_cols;
+        int row = i / num_cols;
+
+        dst[i] = (col == row ? 1.f : 0.f);
+    }
+}
+
+/// Print the matrix
+void print_matrix(size_t num_rows, size_t num_cols, const char* name, const float* src) {
+    std::cout << name << " = [\n";
+    for (size_t y = 0; y < num_rows; ++y) {
+        std::cout << "  [";
+        for (size_t x = 0; x < num_cols; ++x) {
+            std::cout << std::setprecision(2) << std::fixed << src[y * num_cols + x] << ", ";
+        }
+        std::cout << ("],\n");
+    }
+    std::cout << ("]\n\n");
+}
+
+void print_matrix(size_t num_rows, size_t num_cols, const char* name, const bfloat16_t* src) {
+    std::cout << name << " = [\n";
+    for (size_t y = 0; y < num_rows; ++y) {
+        std::cout << "  [";
+        for (size_t x = 0; x < num_cols; ++x) {
+            std::cout << std::setprecision(2) << std::fixed << bf16_to_float(&src[y * num_cols + x]) << ", ";
+        }
+        std::cout << ("],\n");
+    }
+    std::cout << ("]\n\n");
+}
+
+/// Prints a packed matrix whose rows hold nr FP32 values followed by BF16 values.
+void print_mixed_prec_matrix(
+    size_t num_rows, size_t num_cols, const char* name, const bfloat16_t* src, int nr, int stride) {
+    std::cout << name << " = [\n";
+    for (size_t y = 0; y < num_rows; ++y) {
+        std::cout << "  [";
+        const uint8_t* src_row = reinterpret_cast<const uint8_t*>(src) + stride * y;
+        for (size_t x = 0; x < num_cols; ++x) {
+            if (x >= static_cast<size_t>(nr)) {
+                // print bfloat
+                const bfloat16_t* src_elm =
+                    reinterpret_cast<const bfloat16_t*>(src_row + nr * sizeof(float) + (x - nr) * sizeof(bfloat16_t));
+                std::cout << std::setprecision(2) << std::fixed << bf16_to_float(src_elm) << ", ";
+            } else {
+                // print float
+                const float* src_elm = reinterpret_cast<const float*>(src_row + x * sizeof(float));
+                std::cout << std::setprecision(2) << std::fixed << *src_elm << ", ";
+            }
+        }
+        std::cout << ("],\n");
+    }
+    std::cout << ("]\n\n");
+}
+
+void print_bf_matrix(size_t num_rows, size_t num_cols, const char* name, const float* src) {
+    std::cout << name << " = [\n";
+    for (size_t y = 0; y < num_rows; ++y) {
+        std::cout << "  [";
+        for (size_t x = 0; x < num_cols; ++x) {
+            std::cout << std::setprecision(2) << std::fixed << truncate(src[y * num_cols + x]) << ", ";
+        }
+        std::cout << ("],\n");
+    }
+    std::cout << ("]\n\n");
+}
+
+/// Verify the micro-kernel output matches the reference implementation
+bool is_output_correct(
+    size_t num_rows, size_t num_cols, const float rel_tolerance, const float* ref, const float* act) {
+    bool is_valid = true;
+
+    for (size_t i = 0; i < num_rows * num_cols; ++i) {
+        if (std::fabs(ref[i] - act[i]) / (std::fabs(act[i]) + 1e-10) > rel_tolerance) {
+            const size_t x = i % num_cols;
+            const size_t y = i / num_cols;
+
+            std::cout << std::setprecision(5) << std::fixed << "ERROR![" << y << "][" << x << "]: ref=" << ref[i]
+                      << " vs. act=" << act[i] << "\n";
+
+            is_valid = false;
+        }
+    }
+    return is_valid;
+}
+}  // namespace
+
+int main() {
+    // Parameters of the matrix multiplication. Change these values to see how the micro-kernels operate on different
+    // sized matrices.
+    const size_t M = 4;   // Rows of LHS and DST matrices
+    const size_t N = 10;  // Columns of RHS and DST matrices, and length of the Bias vector.
+ const size_t K = 5; // Columns of LHS, rows of RHS matrices + + const size_t lhs_size = M * K; + const size_t rhs_size = N * K; + const size_t bias_size = N; + const size_t dst_size = M * N; + + // Allocate the memory + float* lhs = new float[lhs_size]; + float* rhs = new float[rhs_size]; + float* bias = new float[bias_size]; + + fill_matrix(M, K, lhs, 0.1); + fill_matrix(K, N, rhs, 0.1); + fill_matrix(1, N, bias, 1); + +#ifdef KAI_DEBUG + print_matrix(M, K, "lhs", lhs); + print_matrix(K, N, "rhs", rhs); + print_matrix(1, N, "bias", bias); + + // Print bf16 converted values + print_bf_matrix(M, K, "lhs", lhs); + print_bf_matrix(K, N, "rhs", rhs); +#endif // KAI_DEBUG + + //----------- REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + float* dst_ref = new float[dst_size]; + + run_matmul_ref( + M, N, K, // Dimensions + lhs, // LHS buffer + rhs, // RHS buffer + bias, // Bias buffer + dst_ref, // DST + FLOAT16_MIN, FLOAT16_MAX // Min and max for the clamp operation + ); + //----------- END REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + + //----------- MICRO-KERNELS TESTS + //------------------------------------ + //------------------------------------ + const size_t mr = ukernel.get_mr(); + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + + // In a single row, we pack nr bias values followed by K rows of nr RHS values + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon(N, K); + const size_t rhs_packed_cols = nr + kai_roundup(K, kr) * nr; + + int rhs_packed_stride = nr * sizeof(float) + kai_roundup(K, kr) * nr * sizeof(bfloat16_t); + + // Each col has nr floats and then K*nr bfloats + bfloat16_t* rhs_packed = new bfloat16_t[rhs_packed_size]; + + const size_t lhs_stride = K * sizeof(float); + const size_t rhs_stride = N * sizeof(float); + const size_t dst_stride_row = N * sizeof(float); + const size_t dst_stride_col = sizeof(float); + + // Packing only needs to be performed once if the contents of the bias and RHS matrices are expected to be constant. 
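+    // Unlike the f32_bf16p_bf16p example, this variant packs only the RHS: the f32 LHS is
+    // passed straight to the micro-kernel below (note that run_matmul takes lhs and lhs_stride
+    // rather than a packed LHS buffer), so no LHS packing pass is required.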
+    kai_run_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon(
+        1, N, K, nr, kr, sr,  // Packing arguments
+        rhs_stride,           // RHS stride
+        rhs,                  // RHS
+        bias,                 // Bias
+        NULL,                 // Scale
+        rhs_packed,           // RHS packed
+        0, NULL);
+
+    // The RHS and Bias buffers could be freed here after packing; this example keeps them and frees everything at the end.
+#ifdef KAI_DEBUG
+    const size_t rhs_packed_rows = rhs_packed_size / rhs_packed_stride;
+    print_mixed_prec_matrix(rhs_packed_rows, rhs_packed_cols, "rhs_packed", rhs_packed, nr, rhs_packed_stride);
+#endif  // KAI_DEBUG
+
+    float* dst = new float[dst_size];
+
+    const auto timer_matmul_start = std::chrono::high_resolution_clock::now();
+
+    ukernel.run_matmul(
+        M, N, K,          // Dimensions
+        lhs,              // LHS (not packed)
+        lhs_stride,       // LHS stride
+        rhs_packed,       // RHS packed
+        dst,              // DST
+        dst_stride_row,   // DST stride (row)
+        dst_stride_col,   // DST stride (col)
+        FLT_MIN, FLT_MAX  // Min and max for the clamp operation
+    );
+
+    const auto timer_matmul_end = std::chrono::high_resolution_clock::now();
+    const auto time_matmul =
+        std::chrono::duration_cast<std::chrono::nanoseconds>(timer_matmul_end - timer_matmul_start);
+
+#ifdef KAI_DEBUG
+    print_matrix(M, N, "dst", dst);
+    print_matrix(M, N, "ref", dst_ref);
+#endif  // KAI_DEBUG
+
+    const bool is_valid = is_output_correct(M, N, 0.02 /* rel tol */, dst_ref, dst);
+
+    std::cout << "TEST[matmul_clamp_f32_f32_bf16p]\n";
+    std::cout << "- ukernel: kai_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla\n";
+    if (is_valid) {
+        std::cout << "- Status: PASSED\n";
+        std::cout << "- Performance: " << time_matmul.count() << "ns\n";
+    } else {
+        std::cout << "- Status: FAILED\n";
+        return 1;
+    }
+
+    //----------- END MICRO-KERNELS TESTS
+    //------------------------------------
+    //------------------------------------
+
+    delete[] lhs;
+    delete[] rhs;
+    delete[] bias;
+    delete[] rhs_packed;
+    delete[] dst;
+    delete[] dst_ref;
+
+    return 0;
+}
+
+#endif  // Architectural features check.
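Both examples invoke run_matmul once over the whole problem. The interface header added below also exposes m_step/n_step and per-buffer offset getters so a caller can drive the micro-kernel tile by tile, for example to split work across threads. The following stand-alone sketch is illustrative only: the helper name, the lhs_packed_stride parameter, and the clamp bounds are assumptions, not part of this patch.

#include <cfloat>
#include <cstddef>
#include <cstdint>

#include "matmul_clamp_f32_bf16p_bf16p_interface.h"

// Hypothetical helper: walks the output in m_step x n_step tiles and offsets each
// buffer with the interface's getters. lhs_packed_stride is assumed to be the
// per-row pitch of the packed LHS buffer produced by the LHS packing routine.
void run_matmul_tiled(
    const kai_matmul_clamp_f32_bf16p_bf16p_ukernel& uk, size_t m, size_t n, size_t k,
    const void* lhs_packed, size_t lhs_packed_stride, const void* rhs_packed, float* dst,
    size_t dst_stride_row) {
    const size_t m_step = uk.get_m_step();  // rows per tile; equals the mr packing parameter
    const size_t n_step = uk.get_n_step();  // columns per tile; equals the nr packing parameter

    for (size_t m_idx = 0; m_idx < m; m_idx += m_step) {
        for (size_t n_idx = 0; n_idx < n; n_idx += n_step) {
            const void* lhs_tile = reinterpret_cast<const uint8_t*>(lhs_packed) +
                uk.get_lhs_packed_offset(m_idx, lhs_packed_stride);
            const void* rhs_tile = reinterpret_cast<const uint8_t*>(rhs_packed) +
                uk.get_rhs_packed_offset(n_idx, k);
            void* dst_tile = reinterpret_cast<uint8_t*>(dst) +
                uk.get_dst_offset(m_idx, n_idx, dst_stride_row);

            // The kernel's partial-store paths handle ragged edges, so the tile
            // extents are simply clamped to the matrix bounds.
            const size_t tile_m = (m - m_idx < m_step) ? (m - m_idx) : m_step;
            const size_t tile_n = (n - n_idx < n_step) ? (n - n_idx) : n_step;

            uk.run_matmul(
                tile_m, tile_n, k, lhs_tile, rhs_tile, dst_tile, dst_stride_row, sizeof(float),
                -FLT_MAX, FLT_MAX);  // clamp bounds chosen here to make the clamp a no-op
        }
    }
}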
diff --git a/kai/kai_common.h b/kai/kai_common.h
index dc815f679dbf369b7de8ae0184aa91b4acdd1a38..018c6ccd703dbed125ba950520ad4bb91483461d 100644
--- a/kai/kai_common.h
+++ b/kai/kai_common.h
@@ -23,7 +23,7 @@ extern "C" {
 #define KAI_ERROR(msg)                                        \
     do {                                                      \
         fflush(stdout);                                       \
-        fprintf(stderr, "%s:%d %s", __FILE__, __LINE__, msg); \
+        fprintf(stderr, "%s:%d %s\n", __FILE__, __LINE__, msg); \
         exit(EXIT_FAILURE);                                   \
     } while (0)
@@ -104,7 +104,7 @@ inline static float kai_cast_f32_bf16(uint16_t bf16) {
 inline static uint16_t kai_cast_bf16_f32(float f32) {
     uint16_t bf16 = 0;
 #ifdef __ARM_FEATURE_BF16
-    asm("bfcvt %h[output], %s[input]" : [output] "=w"(bf16) : [input] "w"(f32));
+    __asm("bfcvt %h[output], %s[input]" : [output] "=w"(bf16) : [input] "w"(f32));
 #else
     const uint32_t* i32 = (uint32_t*)(&f32);
     bf16 = (*i32 >> 16);
diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel
index 7dbec37504b97a42251383a0e509a4b5ed0eeabf..ef0dbddddbe1e9d99aaa9ba3b4b112460de678d6 100644
--- a/kai/ukernels/matmul/BUILD.bazel
+++ b/kai/ukernels/matmul/BUILD.bazel
@@ -7,6 +7,7 @@
 load(
     "//:kai_defs.bzl",
     "kai_c_library",
+    "kai_cpu_bf16",
     "kai_cpu_dotprod",
     "kai_cpu_fp16",
     "kai_cpu_i8mm",
@@ -42,6 +43,22 @@ kai_c_library(
     ],
 )
 
+kai_c_library(
+    name = "clamp_f32_bf16p_bf16p_interface",
+    hdrs = ["matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p_interface.h"],
+    cpu_uarch = kai_cpu_bf16(),
+)
+
+kai_c_library(
+    name = "clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla",
+    srcs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c"],
+    hdrs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h"],
+    cpu_uarch = kai_cpu_bf16(),
+    deps = [
+        ":clamp_f32_bf16p_bf16p_interface",
+    ],
+)
+
 kai_c_library(
     name = "clamp_f32_f32_f32p",
     srcs = [
@@ -141,6 +158,13 @@ kai_c_library(
     cpu_uarch = kai_cpu_sme(),
 )
 
+kai_c_library(
+    name = "lhs_pack_f32p8x4_bf16_neon",
+    srcs = ["pack/kai_lhs_pack_f32p8x4_bf16_neon.c"],
+    hdrs = ["pack/kai_lhs_pack_f32p8x4_bf16_neon.h"],
+    cpu_uarch = kai_cpu_bf16(),
+)
+
 kai_c_library(
     name = "rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon",
     srcs = ["pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c"],
@@ -148,6 +172,13 @@ kai_c_library(
     cpu_uarch = kai_cpu_fp16(),
 )
 
+kai_c_library(
+    name = "rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon",
+    srcs = ["pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c"],
+    hdrs = ["pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h"],
+    cpu_uarch = kai_cpu_bf16(),
+)
+
 kai_c_library(
     name = "rhs_pack_kxn_f32pbiasf32_f32_f32_neon",
     srcs = ["pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c"],
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c
new file mode 100644
index 0000000000000000000000000000000000000000..13ebc4e306bba8dfb45e186f83500fb3ac167fd3
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.c
@@ -0,0 +1,583 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
+#error This file must be compiled for AArch64, FEAT_BF16.
+#else  // Architectural features check.
+
+#include <arm_neon.h>
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include "kai/kai_common.h"
+#include "kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h"
+
+static const size_t kai_mr = 8;
+static const size_t kai_nr = 12;
+static const size_t kai_kr = 4;
+static const size_t kai_sr = 1;
+
+size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) {
+    return kai_mr;
+}
+
+size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) {
+    return kai_nr;
+}
+
+size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) {
+    return kai_mr;
+}
+
+size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) {
+    return kai_nr;
+}
+
+size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) {
+    return kai_kr;
+}
+
+size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void) {
+    return kai_sr;
+}
+
+size_t kai_get_lhs_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t m_idx, size_t stride) {
+    KAI_ASSUME(m_idx % kai_mr == 0);
+
+    return m_idx * stride;
+}
+
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t n_idx, size_t k) {
+    KAI_ASSUME(n_idx % kai_nr == 0);
+
+    return n_idx * (sizeof(float) + kai_roundup(k, kai_kr) * sizeof(bfloat16_t));
+}
+
+size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(
+    size_t m_idx, size_t n_idx, size_t stride) {
+    KAI_ASSUME(m_idx % kai_mr == 0);
+    KAI_ASSUME(n_idx % kai_nr == 0);
+
+    return m_idx * stride + n_idx * sizeof(float);
+}
+
+size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t m, size_t n) {
+    return m * n * sizeof(float);
+}
+
+void kai_run_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(
+    size_t m, size_t n, size_t k,                             //
+    const void* lhs_packed,                                   //
+    const void* rhs_packed,                                   //
+    void* dst, size_t dst_stride_row, size_t dst_stride_col,  //
+    float clamp_min, float clamp_max) {
+    KAI_ASSERT(dst_stride_col == sizeof(float));
+
+    const void* Apanel = lhs_packed;
+    void* Cpanel = dst;
+    size_t ldc = dst_stride_row / sizeof(float);
+
+    size_t M = m;
+
+    typedef struct {
+        float maxval;
+        float minval;
+        size_t N;
+        size_t K;
+        const void* Bpanel;
+        void* output_ptr;
+    } KernelArgs;
+
+    KernelArgs ka;
+
+    ka.N = n;
+    ka.K = kai_roundup(k, 4) / 4 - 1;
+
+    ka.Bpanel = rhs_packed;
+
+    // Direct output.
+    ka.output_ptr = dst;
+
+    // Clamping output.
+ ka.maxval = clamp_max; + ka.minval = clamp_min; + + __asm__ __volatile__( + "1:" // Height loop + "add x11, %x[Cpanel], %x[ldc], LSL #2\n" + "add x10, %x[Cpanel], %x[ldc], LSL #1\n" + "add x9, x11, %x[ldc], LSL #1\n" + "cmp %x[M], #0x8\n" + "add x28, %x[Cpanel], %x[ldc], LSL #3\n" + "add x27, %x[Cpanel], %x[ldc]\n" + "add x26, x10, %x[ldc]\n" + "add x25, x11, %x[ldc]\n" + "add x24, x9, %x[ldc]\n" + "bge 2f\n" + "cmp %x[M], #0x2\n" + "mov x24, %x[Cpanel]\n" + "csel x27, x27, %x[Cpanel], GE\n" + "csel x10, x10, %x[Cpanel], GT\n" + "cmp %x[M], #0x4\n" + "csel x26, x26, %x[Cpanel], GE\n" + "csel x11, x11, %x[Cpanel], GT\n" + "cmp %x[M], #0x6\n" + "csel x25, x25, %x[Cpanel], GE\n" + "csel x9, x9, %x[Cpanel], GT\n" + "2:" // all rows valid + "ldr x23, [%x[args_ptr], %[offsetof_N]]\n" + "ldr x22, [%x[args_ptr], %[offsetof_Bpanel]]\n" + "mov x21, %x[Apanel]\n" + "3:" // Width loop + "ldr q4, [x22, #0x0]\n" + "ldr q5, [x22, #0x10]\n" + "mov %x[Apanel], x21\n" + "ldr q6, [x22, #0x20]\n" + "ldr x20, [%x[args_ptr], %[offsetof_K]]\n" + "add x22, x22, #0x30\n" + "ldr q7, [x22, #0x0]\n" + "ldr q0, [%x[Apanel], #0x0]\n" + "ldr q1, [%x[Apanel], #0x10]\n" + "zip1 v8.2d, v4.2d, v4.2d\n" + "ldr q2, [%x[Apanel], #0x20]\n" + "zip2 v11.2d, v4.2d, v4.2d\n" + "ldr q4, [x22, #0x10]\n" + "zip1 v9.2d, v5.2d, v5.2d\n" + "zip2 v12.2d, v5.2d, v5.2d\n" + "cmp x20, #0x2\n" + "zip1 v10.2d, v6.2d, v6.2d\n" + "zip2 v13.2d, v6.2d, v6.2d\n" + "prfm pldl1keep, [%x[Apanel], #0x0]\n" + "mov v14.16b, v8.16b\n" + "mov v17.16b, v11.16b\n" + "prfm pldl1keep, [x22, #0x0]\n" + "mov v15.16b, v9.16b\n" + "mov v18.16b, v12.16b\n" + "prfm pldl1keep, [x22, #0x40]\n" + "mov v16.16b, v10.16b\n" + "mov v19.16b, v13.16b\n" + "prfm pldl1keep, [%x[Apanel], #0x40]\n" + "mov v20.16b, v8.16b\n" + "mov v21.16b, v9.16b\n" + "prfm pldl1keep, [x22, #0x80]\n" + "mov v22.16b, v10.16b\n" + "mov v23.16b, v11.16b\n" + "prfm pldl1keep, [%x[Apanel], #0x80]\n" + "mov v24.16b, v12.16b\n" + "mov v25.16b, v13.16b\n" + "prfm pldl1keep, [x22, #0xc0]\n" + "mov v26.16b, v8.16b\n" + "mov v27.16b, v9.16b\n" + "prfm pldl1keep, [x22, #0x100]\n" + "mov v28.16b, v10.16b\n" + "mov v29.16b, v11.16b\n" + "prfm pldl1keep, [%x[Apanel], #0xc0]\n" + "mov v30.16b, v12.16b\n" + "mov v31.16b, v13.16b\n" + "prfm pldl1keep, [x22, #0x140]\n" + "add x22, x22, #0x20\n" + "add %x[Apanel], %x[Apanel], #0x30\n" + "blt 5f\n" + "4:" // main loop head + "ldr q3, [%x[Apanel], #0x0]\n" + "ldr q5, [x22, #0x0]\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "ldr q6, [x22, #0x10]\n" + ".inst 0x6e44ec0b // bfmmla v11.4s, v0.8h, v4.8h\n" + ".inst 0x6e47ec2e // bfmmla v14.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec31 // bfmmla v17.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + "sub x20, x20, #0x2\n" + ".inst 0x6e44ec57 // bfmmla v23.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7a // bfmmla v26.4s, v3.8h, v7.8h\n" + "ldr q7, [x22, #0x20]\n" + ".inst 0x6e44ec7d // bfmmla v29.4s, v3.8h, v4.8h\n" + "ldr q4, [x22, #0x30]\n" + ".inst 0x6e45ec09 // bfmmla v9.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec2f // bfmmla v15.4s, v1.8h, v5.8h\n" + "cmp x20, #0x2\n" + ".inst 0x6e46ec32 // bfmmla v18.4s, v1.8h, v6.8h\n" + ".inst 0x6e45ec55 // bfmmla v21.4s, v2.8h, v5.8h\n" + "prfm pldl1keep, [%x[Apanel], #0x100]\n" + ".inst 0x6e46ec58 // bfmmla v24.4s, v2.8h, v6.8h\n" + ".inst 0x6e45ec7b // bfmmla v27.4s, v3.8h, v5.8h\n" + "ldr q5, [x22, #0x40]\n" + ".inst 0x6e46ec7e // bfmmla v30.4s, v3.8h, v6.8h\n" + "ldr q6, [x22, #0x50]\n" + ".inst 0x6e47ec0a 
// bfmmla v10.4s, v0.8h, v7.8h\n" + ".inst 0x6e44ec0d // bfmmla v13.4s, v0.8h, v4.8h\n" + "ldr q0, [%x[Apanel], #0x10]\n" + ".inst 0x6e47ec30 // bfmmla v16.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec33 // bfmmla v19.4s, v1.8h, v4.8h\n" + "ldr q1, [%x[Apanel], #0x20]\n" + ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n" + ".inst 0x6e44ec59 // bfmmla v25.4s, v2.8h, v4.8h\n" + "ldr q2, [%x[Apanel], #0x30]\n" + ".inst 0x6e47ec7c // bfmmla v28.4s, v3.8h, v7.8h\n" + "ldr q7, [x22, #0x60]\n" + ".inst 0x6e44ec7f // bfmmla v31.4s, v3.8h, v4.8h\n" + "ldr q3, [%x[Apanel], #0x40]\n" + "ldr q4, [x22, #0x70]\n" + ".inst 0x6e45ec08 // bfmmla v8.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec2e // bfmmla v14.4s, v1.8h, v5.8h\n" + ".inst 0x6e46ec31 // bfmmla v17.4s, v1.8h, v6.8h\n" + "prfm pldl1keep, [x22, #0x180]\n" + ".inst 0x6e45ec54 // bfmmla v20.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "prfm pldl1keep, [x22, #0x1c0]\n" + ".inst 0x6e45ec7a // bfmmla v26.4s, v3.8h, v5.8h\n" + "ldr q5, [x22, #0x80]\n" + ".inst 0x6e46ec7d // bfmmla v29.4s, v3.8h, v6.8h\n" + "ldr q6, [x22, #0x90]\n" + "prfm pldl1keep, [%x[Apanel], #0x140]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + "prfm pldl1keep, [x22, #0x200]\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e47ec2f // bfmmla v15.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec32 // bfmmla v18.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7b // bfmmla v27.4s, v3.8h, v7.8h\n" + "ldr q7, [x22, #0xa0]\n" + ".inst 0x6e44ec7e // bfmmla v30.4s, v3.8h, v4.8h\n" + "ldr q4, [x22, #0xb0]\n" + ".inst 0x6e45ec0a // bfmmla v10.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + "ldr q0, [%x[Apanel], #0x50]\n" + ".inst 0x6e45ec30 // bfmmla v16.4s, v1.8h, v5.8h\n" + ".inst 0x6e46ec33 // bfmmla v19.4s, v1.8h, v6.8h\n" + "ldr q1, [%x[Apanel], #0x60]\n" + ".inst 0x6e45ec56 // bfmmla v22.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" + "ldr q2, [%x[Apanel], #0x70]\n" + ".inst 0x6e45ec7c // bfmmla v28.4s, v3.8h, v5.8h\n" + ".inst 0x6e46ec7f // bfmmla v31.4s, v3.8h, v6.8h\n" + "add %x[Apanel], %x[Apanel], #0x80\n" + "add x22, x22, #0xc0\n" + "bge 4b\n" + "5:" // main loop skip + "ldr q3, [%x[Apanel], #0x0]\n" + "ldr q5, [x22, #0x0]\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "ldr q6, [x22, #0x10]\n" + ".inst 0x6e44ec0b // bfmmla v11.4s, v0.8h, v4.8h\n" + ".inst 0x6e47ec2e // bfmmla v14.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec31 // bfmmla v17.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + "add %x[Apanel], %x[Apanel], #0x10\n" + ".inst 0x6e44ec57 // bfmmla v23.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7a // bfmmla v26.4s, v3.8h, v7.8h\n" + "ldr q7, [x22, #0x20]\n" + ".inst 0x6e44ec7d // bfmmla v29.4s, v3.8h, v4.8h\n" + "ldr q4, [x22, #0x30]\n" + ".inst 0x6e45ec09 // bfmmla v9.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec2f // bfmmla v15.4s, v1.8h, v5.8h\n" + "add x22, x22, #0x40\n" + ".inst 0x6e46ec32 // bfmmla v18.4s, v1.8h, v6.8h\n" + ".inst 0x6e45ec55 // bfmmla v21.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec58 // bfmmla v24.4s, v2.8h, v6.8h\n" + ".inst 0x6e45ec7b // bfmmla v27.4s, v3.8h, v5.8h\n" + ".inst 0x6e46ec7e // bfmmla v30.4s, v3.8h, v6.8h\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + ".inst 0x6e44ec0d // bfmmla v13.4s, v0.8h, v4.8h\n" + ".inst 
0x6e47ec30 // bfmmla v16.4s, v1.8h, v7.8h\n" + ".inst 0x6e44ec33 // bfmmla v19.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n" + ".inst 0x6e44ec59 // bfmmla v25.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7c // bfmmla v28.4s, v3.8h, v7.8h\n" + ".inst 0x6e44ec7f // bfmmla v31.4s, v3.8h, v4.8h\n" + "cbz x20, 6f\n" + "ldr q5, [x22, #0x0]\n" + "ldr q0, [%x[Apanel], #0x0]\n" + "ldr q1, [%x[Apanel], #0x10]\n" + "ldr q6, [x22, #0x10]\n" + "ldr q2, [%x[Apanel], #0x20]\n" + "ldr q3, [%x[Apanel], #0x30]\n" + "add %x[Apanel], %x[Apanel], #0x40\n" + "ldr q7, [x22, #0x20]\n" + "ldr q4, [x22, #0x30]\n" + ".inst 0x6e45ec08 // bfmmla v8.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec2e // bfmmla v14.4s, v1.8h, v5.8h\n" + ".inst 0x6e46ec31 // bfmmla v17.4s, v1.8h, v6.8h\n" + ".inst 0x6e45ec54 // bfmmla v20.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + ".inst 0x6e45ec7a // bfmmla v26.4s, v3.8h, v5.8h\n" + "ldr q5, [x22, #0x40]\n" + ".inst 0x6e46ec7d // bfmmla v29.4s, v3.8h, v6.8h\n" + "ldr q6, [x22, #0x50]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e47ec2f // bfmmla v15.4s, v1.8h, v7.8h\n" + "add x22, x22, #0x60\n" + ".inst 0x6e44ec32 // bfmmla v18.4s, v1.8h, v4.8h\n" + ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" + ".inst 0x6e47ec7b // bfmmla v27.4s, v3.8h, v7.8h\n" + ".inst 0x6e44ec7e // bfmmla v30.4s, v3.8h, v4.8h\n" + ".inst 0x6e45ec0a // bfmmla v10.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e45ec30 // bfmmla v16.4s, v1.8h, v5.8h\n" + ".inst 0x6e46ec33 // bfmmla v19.4s, v1.8h, v6.8h\n" + ".inst 0x6e45ec56 // bfmmla v22.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" + ".inst 0x6e45ec7c // bfmmla v28.4s, v3.8h, v5.8h\n" + ".inst 0x6e46ec7f // bfmmla v31.4s, v3.8h, v6.8h\n" + "6:" // multiply loop done + "add x20, %x[args_ptr], %[offset_max]\n" + "uzp1 v7.2d, v8.2d, v11.2d\n" + "uzp2 v8.2d, v8.2d, v11.2d\n" + "ld1r { v1.4s }, [x20]\n" + "uzp1 v11.2d, v9.2d, v12.2d\n" + "uzp2 v9.2d, v9.2d, v12.2d\n" + "uzp1 v12.2d, v10.2d, v13.2d\n" + "uzp2 v10.2d, v10.2d, v13.2d\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x20]\n" + "uzp1 v13.2d, v14.2d, v17.2d\n" + "uzp2 v14.2d, v14.2d, v17.2d\n" + "uzp1 v17.2d, v15.2d, v18.2d\n" + "uzp2 v15.2d, v15.2d, v18.2d\n" + "cmp x23, #0xc\n" + "uzp1 v18.2d, v16.2d, v19.2d\n" + "uzp2 v16.2d, v16.2d, v19.2d\n" + "uzp1 v19.2d, v20.2d, v23.2d\n" + "uzp2 v20.2d, v20.2d, v23.2d\n" + "uzp1 v23.2d, v21.2d, v24.2d\n" + "uzp2 v21.2d, v21.2d, v24.2d\n" + "uzp1 v24.2d, v22.2d, v25.2d\n" + "uzp2 v22.2d, v22.2d, v25.2d\n" + "uzp1 v25.2d, v26.2d, v29.2d\n" + "uzp2 v26.2d, v26.2d, v29.2d\n" + "uzp1 v29.2d, v27.2d, v30.2d\n" + "uzp2 v27.2d, v27.2d, v30.2d\n" + "uzp1 v30.2d, v28.2d, v31.2d\n" + "uzp2 v28.2d, v28.2d, v31.2d\n" + "fmin v7.4s, v7.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v19.4s, v19.4s, v1.4s\n" + "fmin v23.4s, v23.4s, v1.4s\n" + "fmin v24.4s, v24.4s, v1.4s\n" + "fmin v20.4s, v20.4s, v1.4s\n" + "fmin v21.4s, v21.4s, v1.4s\n" + "fmin 
v22.4s, v22.4s, v1.4s\n" + "fmin v25.4s, v25.4s, v1.4s\n" + "fmin v29.4s, v29.4s, v1.4s\n" + "fmin v30.4s, v30.4s, v1.4s\n" + "fmin v26.4s, v26.4s, v1.4s\n" + "fmin v27.4s, v27.4s, v1.4s\n" + "fmin v28.4s, v28.4s, v1.4s\n" + "fmax v7.4s, v7.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v23.4s, v23.4s, v0.4s\n" + "fmax v24.4s, v24.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "fmax v25.4s, v25.4s, v0.4s\n" + "fmax v29.4s, v29.4s, v0.4s\n" + "fmax v30.4s, v30.4s, v0.4s\n" + "fmax v26.4s, v26.4s, v0.4s\n" + "fmax v27.4s, v27.4s, v0.4s\n" + "fmax v28.4s, v28.4s, v0.4s\n" + "blt 7f\n" + "str q26, [x24, #0x0]\n" + "str q27, [x24, #0x10]\n" + "str q28, [x24, #0x20]\n" + "add x24, x24, #0x30\n" + "str q25, [x9, #0x0]\n" + "str q29, [x9, #0x10]\n" + "str q30, [x9, #0x20]\n" + "add x9, x9, #0x30\n" + "str q20, [x25, #0x0]\n" + "str q21, [x25, #0x10]\n" + "str q22, [x25, #0x20]\n" + "add x25, x25, #0x30\n" + "str q19, [x11, #0x0]\n" + "str q23, [x11, #0x10]\n" + "str q24, [x11, #0x20]\n" + "add x11, x11, #0x30\n" + "str q14, [x26, #0x0]\n" + "str q15, [x26, #0x10]\n" + "str q16, [x26, #0x20]\n" + "add x26, x26, #0x30\n" + "str q13, [x10, #0x0]\n" + "str q17, [x10, #0x10]\n" + "str q18, [x10, #0x20]\n" + "add x10, x10, #0x30\n" + "str q8, [x27, #0x0]\n" + "str q9, [x27, #0x10]\n" + "str q10, [x27, #0x20]\n" + "add x27, x27, #0x30\n" + "str q7, [%x[Cpanel], #0x0]\n" + "str q11, [%x[Cpanel], #0x10]\n" + "str q12, [%x[Cpanel], #0x20]\n" + "add %x[Cpanel], %x[Cpanel], #0x30\n" + "b 14f\n" + "7:" // partial output + "tbz x23, #3, 9f\n" + "st1 { v26.4s }, [x24], #0x10\n" + "st1 { v27.4s }, [x24], #0x10\n" + "st1 { v25.4s }, [x9], #0x10\n" + "st1 { v29.4s }, [x9], #0x10\n" + "st1 { v20.4s }, [x25], #0x10\n" + "st1 { v21.4s }, [x25], #0x10\n" + "st1 { v19.4s }, [x11], #0x10\n" + "st1 { v23.4s }, [x11], #0x10\n" + "st1 { v14.4s }, [x26], #0x10\n" + "st1 { v15.4s }, [x26], #0x10\n" + "st1 { v13.4s }, [x10], #0x10\n" + "st1 { v17.4s }, [x10], #0x10\n" + "st1 { v8.4s }, [x27], #0x10\n" + "st1 { v9.4s }, [x27], #0x10\n" + "st1 { v7.4s }, [%x[Cpanel]], #0x10\n" + "st1 { v11.4s }, [%x[Cpanel]], #0x10\n" + "tbz x23, #1, 8f\n" + "str d28, [x24], #0x8\n" + "str d30, [x9], #0x8\n" + "str d22, [x25], #0x8\n" + "str d24, [x11], #0x8\n" + "str d16, [x26], #0x8\n" + "str d18, [x10], #0x8\n" + "str d10, [x27], #0x8\n" + "str d12, [%x[Cpanel]], #0x8\n" + "tbz x23, #0, 13f\n" + "st1 { v28.s }[2], [x24]\n" + "st1 { v30.s }[2], [x9]\n" + "st1 { v22.s }[2], [x25]\n" + "st1 { v24.s }[2], [x11]\n" + "st1 { v16.s }[2], [x26]\n" + "st1 { v18.s }[2], [x10]\n" + "st1 { v10.s }[2], [x27]\n" + "st1 { v12.s }[2], [%x[Cpanel]]\n" + "b 13f\n" + "8:" // partial result store: partial_1_8 + "tbz x23, #0, 13f\n" + "str s28, [x24, #0x0]\n" + "str s30, [x9, #0x0]\n" + "str s22, [x25, #0x0]\n" + "str s24, [x11, #0x0]\n" + "str s16, [x26, #0x0]\n" + "str s18, [x10, #0x0]\n" + "str s10, [x27, #0x0]\n" + "str s12, [%x[Cpanel], #0x0]\n" + "b 13f\n" + "9:" // partial result store: partial_4_0 + "tbz x23, #2, 11f\n" + "st1 { v26.4s }, [x24], #0x10\n" + "st1 { v25.4s }, [x9], #0x10\n" + "st1 { v20.4s }, [x25], #0x10\n" + "st1 { 
v19.4s }, [x11], #0x10\n" + "st1 { v14.4s }, [x26], #0x10\n" + "st1 { v13.4s }, [x10], #0x10\n" + "st1 { v8.4s }, [x27], #0x10\n" + "st1 { v7.4s }, [%x[Cpanel]], #0x10\n" + "tbz x23, #1, 10f\n" + "str d27, [x24], #0x8\n" + "str d29, [x9], #0x8\n" + "str d21, [x25], #0x8\n" + "str d23, [x11], #0x8\n" + "str d15, [x26], #0x8\n" + "str d17, [x10], #0x8\n" + "str d9, [x27], #0x8\n" + "str d11, [%x[Cpanel]], #0x8\n" + "tbz x23, #0, 13f\n" + "st1 { v27.s }[2], [x24]\n" + "st1 { v29.s }[2], [x9]\n" + "st1 { v21.s }[2], [x25]\n" + "st1 { v23.s }[2], [x11]\n" + "st1 { v15.s }[2], [x26]\n" + "st1 { v17.s }[2], [x10]\n" + "st1 { v9.s }[2], [x27]\n" + "st1 { v11.s }[2], [%x[Cpanel]]\n" + "b 13f\n" + "10:" // partial result store: partial_1_4 + "tbz x23, #0, 13f\n" + "str s27, [x24, #0x0]\n" + "str s29, [x9, #0x0]\n" + "str s21, [x25, #0x0]\n" + "str s23, [x11, #0x0]\n" + "str s15, [x26, #0x0]\n" + "str s17, [x10, #0x0]\n" + "str s9, [x27, #0x0]\n" + "str s11, [%x[Cpanel], #0x0]\n" + "b 13f\n" + "11:" // partial result store: partial_2_0 + "tbz x23, #1, 12f\n" + "str d26, [x24], #0x8\n" + "str d25, [x9], #0x8\n" + "str d20, [x25], #0x8\n" + "str d19, [x11], #0x8\n" + "str d14, [x26], #0x8\n" + "str d13, [x10], #0x8\n" + "str d8, [x27], #0x8\n" + "str d7, [%x[Cpanel]], #0x8\n" + "tbz x23, #0, 13f\n" + "st1 { v26.s }[2], [x24]\n" + "st1 { v25.s }[2], [x9]\n" + "st1 { v20.s }[2], [x25]\n" + "st1 { v19.s }[2], [x11]\n" + "st1 { v14.s }[2], [x26]\n" + "st1 { v13.s }[2], [x10]\n" + "st1 { v8.s }[2], [x27]\n" + "st1 { v7.s }[2], [%x[Cpanel]]\n" + "b 13f\n" + "12:" // partial result store: partial_1_0 + "str s26, [x24, #0x0]\n" + "str s25, [x9, #0x0]\n" + "str s20, [x25, #0x0]\n" + "str s19, [x11, #0x0]\n" + "str s14, [x26, #0x0]\n" + "str s13, [x10, #0x0]\n" + "str s8, [x27, #0x0]\n" + "str s7, [%x[Cpanel], #0x0]\n" + "13:" // partial result store: Done + "14:" // store done + "subs x23, x23, #0xc\n" + "bgt 3b\n" + "subs %x[M], %x[M], #0x8\n" + "mov %x[Cpanel], x28\n" + "bgt 1b\n" + : [Apanel] "+&r"(Apanel), [Cpanel] "+&r"(Cpanel), [M] "+&r"(M) + : [args_ptr] "r"(&ka), [ldc] "r"(ldc * sizeof(float)), [offset_max] "I"(offsetof(KernelArgs, maxval)), + [offset_min] "I"(offsetof(KernelArgs, minval)), [offsetof_Bpanel] "I"(offsetof(KernelArgs, Bpanel)), + [offsetof_K] "I"(offsetof(KernelArgs, K)), [offsetof_N] "I"(offsetof(KernelArgs, N)) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", "x9", "x10", "x11", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h new file mode 100644 index 0000000000000000000000000000000000000000..000d690a6cc4a829df719905e024d31961980a83 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h @@ -0,0 +1,127 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_BF16. +#else // Architectural features check. 
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+/// --------------------------------------------------
+
+/// Gets m step value.
+///
+/// The starting row index must be divisible by `m_step`.
+///
+/// @return The m step value.
+size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void);
+
+/// Gets n step value.
+///
+/// The starting column index must be divisible by `n_step`.
+///
+/// @return The n step value.
+size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void);
+
+/// Gets mr value.
+///
+/// This is the packing parameter which must be used to pack the LHS matrix.
+///
+/// @return The mr value.
+size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void);
+
+/// Gets nr value.
+///
+/// This is the packing parameter which must be used to pack the RHS matrix.
+///
+/// @return The nr value.
+size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void);
+
+/// Gets kr value.
+///
+/// This is the packing parameter which must be used to pack the LHS & RHS matrices.
+///
+/// @return The kr value.
+size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void);
+
+/// Gets sr value.
+///
+/// This is the packing parameter which must be used to pack the RHS matrix.
+///
+/// @return The sr value.
+size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(void);
+
+/// Gets the offset in bytes to the data element in the packed LHS matrix buffer.
+///
+/// @param[in] m_idx Row index.
+/// @param[in] stride Row stride in bytes.
+///
+/// @return The offset in bytes to the data element.
+size_t kai_get_lhs_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t m_idx, size_t stride);
+
+/// Gets the offset in bytes to the data element in the packed RHS matrix buffer.
+///
+/// @param[in] n_idx Column index.
+/// @param[in] k Common dimension of the LHS and RHS matrices.
+///
+/// @return The offset in bytes to the data element.
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t n_idx, size_t k);
+
+/// Gets the offset in bytes to the data element in the destination matrix buffer.
+///
+/// @param[in] m_idx Row index.
+/// @param[in] n_idx Column index.
+/// @param[in] stride Row stride in bytes.
+///
+/// @return The offset in bytes to the data element.
+size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(
+    size_t m_idx, size_t n_idx, size_t stride);
+
+/// Gets the size in bytes of the destination matrix buffer.
+///
+/// @param[in] m Number of rows.
+/// @param[in] n Number of columns.
+///
+/// @return The size in bytes of the destination matrix buffer.
+size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(size_t m, size_t n);
+
+/// Runs the matrix multiplication micro-kernel followed by a clamp operation.
+///
+/// The pointer to each buffer (packed LHS, packed RHS and output) must be offset by the value
+/// calculated with the corresponding function:
+///
+/// * Packed LHS: @ref kai_get_lhs_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.
+/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.
+/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.
+///
+/// @param[in] m Number of output rows to be computed.
+/// @param[in] n Number of output columns to be computed.
+/// @param[in] k Common dimension of the LHS and RHS operands.
+/// @param[in] lhs_packed Packed LHS matrix buffer.
+/// @param[in] rhs_packed Packed RHS buffer.
+/// @param[out] dst Output matrix buffer.
+/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float).
+/// @param[in] clamp_min Minimum value to clamp the final result.
+/// @param[in] clamp_max Maximum value to clamp the final result.
+void kai_run_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla(
+    size_t m, size_t n, size_t k,                             //
+    const void* lhs_packed,                                   //
+    const void* rhs_packed,                                   //
+    void* dst, size_t dst_stride_row, size_t dst_stride_col,  //
+    float clamp_min, float clamp_max);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // Architectural features check.
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p_interface.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p_interface.h
new file mode 100644
index 0000000000000000000000000000000000000000..8eb1969acf0692e4d16948b23d080c43042af231
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p_interface.h
@@ -0,0 +1,57 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
+#error This file must be compiled for AArch64, FEAT_BF16.
+#else  // Architectural features check.
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// All micro-kernel variants of the same type share the same interface.
+// In this case, the micro-kernel type is: matmul_clamp_f32_bf16p_bf16p
+
+/// Micro-kernel helper functions ("get" methods)
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_m_step_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_n_step_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_mr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_nr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_kr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_sr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_offset_func_t)(size_t m_idx, size_t lhs_stride);
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k);
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride);
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_size_func_t)(size_t m, size_t n);
+
+/// Micro-kernel core function ("run" method)
+typedef void (*kai_matmul_clamp_f32_bf16p_bf16p_run_matmul_func_t)(
+    size_t m, size_t n, size_t k, const void* lhs, const void* rhs_packed, void* dst, size_t dst_stride_row,
+    size_t dst_stride_col, float scalar_min, float scalar_max);
+
+/// Micro-kernel interface
+struct kai_matmul_clamp_f32_bf16p_bf16p_ukernel {
+    kai_matmul_clamp_f32_bf16p_bf16p_get_m_step_func_t get_m_step;
+    kai_matmul_clamp_f32_bf16p_bf16p_get_n_step_func_t get_n_step;
+    kai_matmul_clamp_f32_bf16p_bf16p_get_mr_func_t get_mr;
+    kai_matmul_clamp_f32_bf16p_bf16p_get_nr_func_t get_nr;
+    kai_matmul_clamp_f32_bf16p_bf16p_get_kr_func_t get_kr;
+    kai_matmul_clamp_f32_bf16p_bf16p_get_sr_func_t get_sr;
+    kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_offset_func_t get_lhs_packed_offset;
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_bf16p/kai_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla.c b/kai/ukernels/matmul/matmul_clamp_f32_f32_bf16p/kai_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla.c
new file mode 100644
index 0000000000000000000000000000000000000000..14c0ebc0f08018e325ea896c5e96791412fe3af9
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_bf16p/kai_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla.c
@@ -0,0 +1,2454 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
+#error This file must be compiled for AArch64, FEAT_BF16.
+#else  // Architectural features check.
+
+#include "kai_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla.h"
+
+#include <arm_neon.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "kai/kai_common.h"
+
+static const size_t kai_mr = 4;
+static const size_t kai_nr = 24;
+static const size_t kai_kr = 4;
+static const size_t kai_sr = 1;
+
+size_t kai_get_m_step_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(void) {
+    return kai_mr;
+}
+
+size_t kai_get_n_step_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(void) {
+    return kai_nr;
+}
+
+size_t kai_get_mr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(void) {
+    return kai_mr;
+}
+
+size_t kai_get_nr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(void) {
+    return kai_nr;
+}
+
+size_t kai_get_kr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(void) {
+    return kai_kr;
+}
+
+size_t kai_get_sr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(void) {
+    return kai_sr;
+}
+
+size_t kai_get_lhs_offset_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(size_t m_idx, size_t stride) {
+    KAI_ASSUME(m_idx % kai_mr == 0);
+
+    return m_idx * stride;
+}
+
+// Each packed RHS column stores one f32 bias followed by its BF16 weights for k rounded up to a
+// multiple of kr, so the per-column stride is sizeof(float) + kai_roundup(k, kai_kr) * sizeof(bfloat16_t).
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(size_t n_idx, size_t k) {
+    KAI_ASSUME(n_idx % kai_nr == 0);
+
+    return n_idx * (sizeof(float) + kai_roundup(k, kai_kr) * sizeof(bfloat16_t));
+}
+
+size_t kai_get_dst_offset_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(
+    size_t m_idx, size_t n_idx, size_t stride) {
+    KAI_ASSUME(m_idx % kai_mr == 0);
+    KAI_ASSUME(n_idx % kai_nr == 0);
+
+    return m_idx * stride + n_idx * sizeof(float);
+}
+
+size_t kai_get_dst_size_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(size_t m, size_t n) {
+    return m * n * sizeof(float);
+}
+
+void kai_run_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(
+    size_t m, size_t n, size_t k,                             //
+    const void* lhs, size_t lhs_stride,                       //
+    const void* rhs_packed,                                   //
+    void* dst, size_t dst_stride_row, size_t dst_stride_col,  //
+    float clamp_min, float clamp_max) {
+    KAI_ASSERT(dst_stride_col == sizeof(float));
+
+    // Argument block read by the generated assembly below.
+    typedef struct {
+        float maxval;
+        float minval;
+        unsigned int num_strings;
+        const unsigned int* string_lengths;
+        size_t N;
+        const void* B_ptr;
+        size_t output_offset;
+        size_t input_initial_col;
+        size_t input_offset;
+        void* output_ptr;
+        const void* bias;
+    } KernelArgs;
+
+    KernelArgs ka;
+
+    unsigned long flags = 0;
+
+    unsigned int
string_length = k; + ka.num_strings = 1; + ka.string_lengths = &string_length; + ka.N = n; + ka.B_ptr = rhs_packed; + ka.bias = NULL; + + // Direct input. + const void* input_ptr = lhs; + ka.input_offset = lhs_stride / sizeof(float); + ka.input_initial_col = 0; + + // Direct output. + ka.output_ptr = dst; + ka.output_offset = dst_stride_row / sizeof(float); + + // Clamping output. + flags |= 0x2; + ka.maxval = clamp_max; + ka.minval = clamp_min; + + __asm__ __volatile__( + "1:" // Row loop + "cmp %x[m], #0x4\n" + "bge 130f\n" + "cmp %x[m], #0x2\n" + "bgt 87f\n" + "beq 44f\n" + "ldr x9, [%x[args_ptr], %[offsetof_N]]\n" + "ldr x28, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x27, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "2:" // Height 1: Column loop + "cbz x28, 3f\n" + "ldr q8, [x28, #0x0]\n" + "ldr q9, [x28, #0x10]\n" + "ldr q10, [x28, #0x20]\n" + "ldr q11, [x28, #0x30]\n" + "ldr q12, [x28, #0x40]\n" + "ldr q13, [x28, #0x50]\n" + "add x28, x28, #0x60\n" + "zip2 v14.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "zip2 v15.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v16.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v17.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "zip2 v18.2d, v12.2d, v12.2d\n" + "zip1 v12.2d, v12.2d, v12.2d\n" + "zip2 v19.2d, v13.2d, v13.2d\n" + "zip1 v13.2d, v13.2d, v13.2d\n" + "b 19f\n" + "3:" // Height 1: no bias + "tbz %x[flags], #0, 18f\n" + "cmp x9, #0x18\n" + "bge 16f\n" + "tbz x9, #4, 7f\n" + "ld1 { v9.4s }, [x27], #0x10\n" + "ld1 { v10.4s }, [x27], #0x10\n" + "ld1 { v11.4s }, [x27], #0x10\n" + "ld1 { v12.4s }, [x27], #0x10\n" + "tbz x9, #2, 5f\n" + "ld1 { v13.4s }, [x27], #0x10\n" + "tbz x9, #1, 4f\n" + "ldr d20, [x27], #0x8\n" + "mov x20, #0x58\n" + "tbz x9, #0, 15f\n" + "ld1 { v20.s }[2], [x27]\n" + "b 15f\n" + "4:" // Height 1: Partial accumulate: partial_1_20 + "mov x20, #0x50\n" + "tbz x9, #0, 15f\n" + "ldr s20, [x27, #0x0]\n" + "b 15f\n" + "5:" // Height 1: Partial accumulate: partial_2_16 + "tbz x9, #1, 6f\n" + "ldr d13, [x27], #0x8\n" + "mov x20, #0x48\n" + "tbz x9, #0, 15f\n" + "ld1 { v13.s }[2], [x27]\n" + "b 15f\n" + "6:" // Height 1: Partial accumulate: partial_1_16 + "mov x20, #0x40\n" + "tbz x9, #0, 15f\n" + "ldr s13, [x27, #0x0]\n" + "b 15f\n" + "7:" // Height 1: Partial accumulate: partial_8_0 + "tbz x9, #3, 11f\n" + "ld1 { v9.4s }, [x27], #0x10\n" + "ld1 { v10.4s }, [x27], #0x10\n" + "tbz x9, #2, 9f\n" + "ld1 { v11.4s }, [x27], #0x10\n" + "tbz x9, #1, 8f\n" + "ldr d12, [x27], #0x8\n" + "mov x20, #0x38\n" + "tbz x9, #0, 15f\n" + "ld1 { v12.s }[2], [x27]\n" + "b 15f\n" + "8:" // Height 1: Partial accumulate: partial_1_12 + "mov x20, #0x30\n" + "tbz x9, #0, 15f\n" + "ldr s12, [x27, #0x0]\n" + "b 15f\n" + "9:" // Height 1: Partial accumulate: partial_2_8 + "tbz x9, #1, 10f\n" + "ldr d11, [x27], #0x8\n" + "mov x20, #0x28\n" + "tbz x9, #0, 15f\n" + "ld1 { v11.s }[2], [x27]\n" + "b 15f\n" + "10:" // Height 1: Partial accumulate: partial_1_8 + "mov x20, #0x20\n" + "tbz x9, #0, 15f\n" + "ldr s11, [x27, #0x0]\n" + "b 15f\n" + "11:" // Height 1: Partial accumulate: partial_4_0 + "tbz x9, #2, 13f\n" + "ld1 { v9.4s }, [x27], #0x10\n" + "tbz x9, #1, 12f\n" + "ldr d10, [x27], #0x8\n" + "mov x20, #0x18\n" + "tbz x9, #0, 15f\n" + "ld1 { v10.s }[2], [x27]\n" + "b 15f\n" + "12:" // Height 1: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x9, #0, 15f\n" + "ldr s10, [x27, #0x0]\n" + "b 15f\n" + "13:" // Height 1: Partial accumulate: partial_2_0 + "tbz x9, #1, 14f\n" + "ldr d9, [x27], #0x8\n" + "mov x20, 
#0x8\n" + "tbz x9, #0, 15f\n" + "ld1 { v9.s }[2], [x27]\n" + "b 15f\n" + "14:" // Height 1: Partial accumulate: partial_1_0 + "ldr s9, [x27, #0x0]\n" + "mov x20, #0x0\n" + "15:" // Height 1: Partial accumulate: Done + "sub x27, x27, x20\n" + "b 17f\n" + "16:" // Height 1: full accumulate + "ldr q9, [x27, #0x0]\n" + "ldr q10, [x27, #0x10]\n" + "ldr q11, [x27, #0x20]\n" + "ldr q12, [x27, #0x30]\n" + "ldr q13, [x27, #0x40]\n" + "ldr q20, [x27, #0x50]\n" + "17:" // Height 1: MMLA fixup + "zip1 v8.2d, v9.2d, v14.2d\n" + "zip2 v14.2d, v9.2d, v14.2d\n" + "zip1 v9.2d, v10.2d, v15.2d\n" + "zip2 v15.2d, v10.2d, v15.2d\n" + "zip1 v10.2d, v11.2d, v16.2d\n" + "zip2 v16.2d, v11.2d, v16.2d\n" + "zip1 v11.2d, v12.2d, v17.2d\n" + "zip2 v17.2d, v12.2d, v17.2d\n" + "zip1 v12.2d, v13.2d, v18.2d\n" + "zip2 v18.2d, v13.2d, v18.2d\n" + "zip1 v13.2d, v20.2d, v19.2d\n" + "zip2 v19.2d, v20.2d, v19.2d\n" + "b 19f\n" + "18:" // Height 1: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "19:" // Height 1: setup done + "mov x26, #0x0\n" + "20:" // Height 1: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w25, [x20, x26, LSL #0x2]\n" + "tbz %x[flags], #3, 21f\n" + "ldr x20, [%x[input_ptr], x26, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x24, [x20, #0x0]\n" + "cbnz x26, 22f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x24, x24, x20, LSL #2\n" + "b 22f\n" + "21:" // Height 1: setup direct input + "mov x24, %x[input_ptr]\n" + "22:" // Height 1: input setup done + "cmp x25, #0x4\n" + "blt 25f\n" + "ld1 { v0.4s }, [x24], #0x10\n" + "ldr q4, [x28, #0x0]\n" + "cmp x25, #0x8\n" + "ldr q5, [x28, #0x10]\n" + "ldr q6, [x28, #0x20]\n" + "ldr q7, [x28, #0x30]\n" + "blt 24f\n" + "23:" // Height 1: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "sub x25, x25, #0x4\n" + "prfm pldl1keep, [x24, #0x80]\n" + "cmp x25, #0x8\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + "ldr q4, [x28, #0x40]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + "ldr q5, [x28, #0x50]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x60]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0x70]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + "ldr q4, [x28, #0x80]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + "ldr q5, [x28, #0x90]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0xa0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0xb0]\n" + "add x28, x28, #0xc0\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + "ldr q4, [x28, #0x0]\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + "ldr q5, [x28, #0x10]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x20]\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + "ld1 { v0.4s }, [x24], #0x10\n" + "ldr q7, [x28, #0x30]\n" + "bge 23b\n" + "24:" // Height 1: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "sub x25, x25, #0x4\n" + "prfm pldl1keep, [x24, #0x80]\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + "ldr q4, [x28, #0x40]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, 
v5.8h\n" + "ldr q5, [x28, #0x50]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x60]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0x70]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + "ldr q4, [x28, #0x80]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + "ldr q5, [x28, #0x90]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0xa0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0xb0]\n" + "add x28, x28, #0xc0\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + "25:" // Height 1: Multiply loop: Main loop skip + "cbz x25, 28f\n" + "cbz x25, 28f\n" + "tbz x25, #1, 26f\n" + "ldr d0, [x24], #0x8\n" + "tbz x25, #0, 27f\n" + "ld1 { v0.s }[2], [x24]\n" + "b 27f\n" + "26:" // Height 1: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x24, #0x0]\n" + "27:" // Height 1: Multiply loop: Ragged operand read: Done + "ldr q4, [x28, #0x0]\n" + "ldr q5, [x28, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "ldr q6, [x28, #0x20]\n" + "ldr q7, [x28, #0x30]\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + "ldr q4, [x28, #0x40]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + "ldr q5, [x28, #0x50]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x60]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0x70]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + "ldr q4, [x28, #0x80]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + "ldr q5, [x28, #0x90]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0xa0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0xb0]\n" + "add x28, x28, #0xc0\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + "28:" // Height 1: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x26, x26, #0x1\n" + "cmp x26, x20\n" + "bne 20b\n" + "uzp1 v8.2d, v8.2d, v14.2d\n" + "uzp1 v9.2d, v9.2d, v15.2d\n" + "prfm pstl1keep, [x27, #0x0]\n" + "uzp1 v10.2d, v10.2d, v16.2d\n" + "uzp1 v11.2d, v11.2d, v17.2d\n" + "uzp1 v12.2d, v12.2d, v18.2d\n" + "uzp1 v13.2d, v13.2d, v19.2d\n" + "tbz %x[flags], #1, 29f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v22.4s }, [x21]\n" + "ld1r { v21.4s }, [x20]\n" + "fmin v8.4s, v8.4s, v22.4s\n" + "fmin v9.4s, v9.4s, v22.4s\n" + "fmin v10.4s, v10.4s, v22.4s\n" + "fmin v11.4s, v11.4s, v22.4s\n" + "fmin v12.4s, v12.4s, v22.4s\n" + "fmin v13.4s, v13.4s, v22.4s\n" + "fmax v8.4s, v8.4s, v21.4s\n" + "fmax v9.4s, v9.4s, v21.4s\n" + "fmax v10.4s, v10.4s, v21.4s\n" + "fmax v11.4s, v11.4s, v21.4s\n" + "fmax v12.4s, v12.4s, v21.4s\n" + "fmax v13.4s, v13.4s, v21.4s\n" + "29:" // Height 1: No activation + "cmp x9, #0x18\n" + "bge 42f\n" + "tbz x9, #4, 33f\n" + "st1 { v8.4s }, [x27], #0x10\n" + "st1 { v9.4s }, [x27], #0x10\n" + "st1 { v10.4s }, [x27], #0x10\n" + "st1 { v11.4s }, [x27], #0x10\n" + "tbz x9, #2, 31f\n" + "st1 { v12.4s }, [x27], #0x10\n" + "tbz x9, #1, 30f\n" + "str d13, [x27], #0x8\n" + "tbz x9, #0, 41f\n" + "st1 { v13.s }[2], [x27]\n" + "b 41f\n" + "30:" // Height 
1: Partial direct writeback: partial_1_20 + "tbz x9, #0, 41f\n" + "str s13, [x27, #0x0]\n" + "b 41f\n" + "31:" // Height 1: Partial direct writeback: partial_2_16 + "tbz x9, #1, 32f\n" + "str d12, [x27], #0x8\n" + "tbz x9, #0, 41f\n" + "st1 { v12.s }[2], [x27]\n" + "b 41f\n" + "32:" // Height 1: Partial direct writeback: partial_1_16 + "tbz x9, #0, 41f\n" + "str s12, [x27, #0x0]\n" + "b 41f\n" + "33:" // Height 1: Partial direct writeback: partial_8_0 + "tbz x9, #3, 37f\n" + "st1 { v8.4s }, [x27], #0x10\n" + "st1 { v9.4s }, [x27], #0x10\n" + "tbz x9, #2, 35f\n" + "st1 { v10.4s }, [x27], #0x10\n" + "tbz x9, #1, 34f\n" + "str d11, [x27], #0x8\n" + "tbz x9, #0, 41f\n" + "st1 { v11.s }[2], [x27]\n" + "b 41f\n" + "34:" // Height 1: Partial direct writeback: partial_1_12 + "tbz x9, #0, 41f\n" + "str s11, [x27, #0x0]\n" + "b 41f\n" + "35:" // Height 1: Partial direct writeback: partial_2_8 + "tbz x9, #1, 36f\n" + "str d10, [x27], #0x8\n" + "tbz x9, #0, 41f\n" + "st1 { v10.s }[2], [x27]\n" + "b 41f\n" + "36:" // Height 1: Partial direct writeback: partial_1_8 + "tbz x9, #0, 41f\n" + "str s10, [x27, #0x0]\n" + "b 41f\n" + "37:" // Height 1: Partial direct writeback: partial_4_0 + "tbz x9, #2, 39f\n" + "st1 { v8.4s }, [x27], #0x10\n" + "tbz x9, #1, 38f\n" + "str d9, [x27], #0x8\n" + "tbz x9, #0, 41f\n" + "st1 { v9.s }[2], [x27]\n" + "b 41f\n" + "38:" // Height 1: Partial direct writeback: partial_1_4 + "tbz x9, #0, 41f\n" + "str s9, [x27, #0x0]\n" + "b 41f\n" + "39:" // Height 1: Partial direct writeback: partial_2_0 + "tbz x9, #1, 40f\n" + "str d8, [x27], #0x8\n" + "tbz x9, #0, 41f\n" + "st1 { v8.s }[2], [x27]\n" + "b 41f\n" + "40:" // Height 1: Partial direct writeback: partial_1_0 + "str s8, [x27, #0x0]\n" + "41:" // Height 1: Partial direct writeback: Done + "b 43f\n" + "42:" // Height 1: Full writeback + "str q8, [x27, #0x0]\n" + "str q9, [x27, #0x10]\n" + "str q10, [x27, #0x20]\n" + "str q11, [x27, #0x30]\n" + "str q12, [x27, #0x40]\n" + "str q13, [x27, #0x50]\n" + "add x27, x27, #0x60\n" + "43:" // Height 1: Writeback done + "subs x9, x9, #0x18\n" + "bgt 2b\n" + "b 174f\n" + "44:" // Height 2 + "ldr x9, [%x[args_ptr], %[offsetof_N]]\n" + "ldr x28, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x27, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "45:" // Height 2: Column loop + "cbz x28, 46f\n" + "ldr q8, [x28, #0x0]\n" + "ldr q9, [x28, #0x10]\n" + "ldr q10, [x28, #0x20]\n" + "ldr q11, [x28, #0x30]\n" + "ldr q12, [x28, #0x40]\n" + "ldr q13, [x28, #0x50]\n" + "add x28, x28, #0x60\n" + "zip2 v14.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "zip2 v15.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v16.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v17.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "zip2 v18.2d, v12.2d, v12.2d\n" + "zip1 v12.2d, v12.2d, v12.2d\n" + "zip2 v19.2d, v13.2d, v13.2d\n" + "zip1 v13.2d, v13.2d, v13.2d\n" + "b 62f\n" + "46:" // Height 2: no bias + "tbz %x[flags], #0, 61f\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x9, #0x18\n" + "add x24, x27, x20, LSL #2\n" + "bge 59f\n" + "tbz x9, #4, 50f\n" + "ld1 { v9.4s }, [x27], #0x10\n" + "ld1 { v14.4s }, [x24], #0x10\n" + "ld1 { v10.4s }, [x27], #0x10\n" + "ld1 { v15.4s }, [x24], #0x10\n" + "ld1 { v11.4s }, [x27], #0x10\n" + "ld1 { v16.4s }, [x24], #0x10\n" + "ld1 { v12.4s }, [x27], #0x10\n" + "ld1 { v17.4s }, [x24], #0x10\n" + "tbz x9, #2, 48f\n" + "ld1 { v13.4s }, [x27], #0x10\n" + "ld1 { v18.4s }, [x24], #0x10\n" + "tbz x9, #1, 47f\n" + "ldr d20, [x27], #0x8\n" + 
"ldr d19, [x24], #0x8\n" + "mov x20, #0x58\n" + "tbz x9, #0, 58f\n" + "ld1 { v20.s }[2], [x27]\n" + "ld1 { v19.s }[2], [x24]\n" + "b 58f\n" + "47:" // Height 2: Partial accumulate: partial_1_20 + "mov x20, #0x50\n" + "tbz x9, #0, 58f\n" + "ldr s20, [x27, #0x0]\n" + "ldr s19, [x24, #0x0]\n" + "b 58f\n" + "48:" // Height 2: Partial accumulate: partial_2_16 + "tbz x9, #1, 49f\n" + "ldr d13, [x27], #0x8\n" + "ldr d18, [x24], #0x8\n" + "mov x20, #0x48\n" + "tbz x9, #0, 58f\n" + "ld1 { v13.s }[2], [x27]\n" + "ld1 { v18.s }[2], [x24]\n" + "b 58f\n" + "49:" // Height 2: Partial accumulate: partial_1_16 + "mov x20, #0x40\n" + "tbz x9, #0, 58f\n" + "ldr s13, [x27, #0x0]\n" + "ldr s18, [x24, #0x0]\n" + "b 58f\n" + "50:" // Height 2: Partial accumulate: partial_8_0 + "tbz x9, #3, 54f\n" + "ld1 { v9.4s }, [x27], #0x10\n" + "ld1 { v14.4s }, [x24], #0x10\n" + "ld1 { v10.4s }, [x27], #0x10\n" + "ld1 { v15.4s }, [x24], #0x10\n" + "tbz x9, #2, 52f\n" + "ld1 { v11.4s }, [x27], #0x10\n" + "ld1 { v16.4s }, [x24], #0x10\n" + "tbz x9, #1, 51f\n" + "ldr d12, [x27], #0x8\n" + "ldr d17, [x24], #0x8\n" + "mov x20, #0x38\n" + "tbz x9, #0, 58f\n" + "ld1 { v12.s }[2], [x27]\n" + "ld1 { v17.s }[2], [x24]\n" + "b 58f\n" + "51:" // Height 2: Partial accumulate: partial_1_12 + "mov x20, #0x30\n" + "tbz x9, #0, 58f\n" + "ldr s12, [x27, #0x0]\n" + "ldr s17, [x24, #0x0]\n" + "b 58f\n" + "52:" // Height 2: Partial accumulate: partial_2_8 + "tbz x9, #1, 53f\n" + "ldr d11, [x27], #0x8\n" + "ldr d16, [x24], #0x8\n" + "mov x20, #0x28\n" + "tbz x9, #0, 58f\n" + "ld1 { v11.s }[2], [x27]\n" + "ld1 { v16.s }[2], [x24]\n" + "b 58f\n" + "53:" // Height 2: Partial accumulate: partial_1_8 + "mov x20, #0x20\n" + "tbz x9, #0, 58f\n" + "ldr s11, [x27, #0x0]\n" + "ldr s16, [x24, #0x0]\n" + "b 58f\n" + "54:" // Height 2: Partial accumulate: partial_4_0 + "tbz x9, #2, 56f\n" + "ld1 { v9.4s }, [x27], #0x10\n" + "ld1 { v14.4s }, [x24], #0x10\n" + "tbz x9, #1, 55f\n" + "ldr d10, [x27], #0x8\n" + "ldr d15, [x24], #0x8\n" + "mov x20, #0x18\n" + "tbz x9, #0, 58f\n" + "ld1 { v10.s }[2], [x27]\n" + "ld1 { v15.s }[2], [x24]\n" + "b 58f\n" + "55:" // Height 2: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x9, #0, 58f\n" + "ldr s10, [x27, #0x0]\n" + "ldr s15, [x24, #0x0]\n" + "b 58f\n" + "56:" // Height 2: Partial accumulate: partial_2_0 + "tbz x9, #1, 57f\n" + "ldr d9, [x27], #0x8\n" + "ldr d14, [x24], #0x8\n" + "mov x20, #0x8\n" + "tbz x9, #0, 58f\n" + "ld1 { v9.s }[2], [x27]\n" + "ld1 { v14.s }[2], [x24]\n" + "b 58f\n" + "57:" // Height 2: Partial accumulate: partial_1_0 + "ldr s9, [x27, #0x0]\n" + "ldr s14, [x24, #0x0]\n" + "mov x20, #0x0\n" + "58:" // Height 2: Partial accumulate: Done + "sub x27, x27, x20\n" + "b 60f\n" + "59:" // Height 2: full accumulate + "ldr q9, [x27, #0x0]\n" + "ldr q10, [x27, #0x10]\n" + "ldr q11, [x27, #0x20]\n" + "ldr q12, [x27, #0x30]\n" + "ldr q13, [x27, #0x40]\n" + "ldr q20, [x27, #0x50]\n" + "ldr q14, [x24, #0x0]\n" + "ldr q15, [x24, #0x10]\n" + "ldr q16, [x24, #0x20]\n" + "ldr q17, [x24, #0x30]\n" + "ldr q18, [x24, #0x40]\n" + "ldr q19, [x24, #0x50]\n" + "60:" // Height 2: MMLA fixup + "zip1 v8.2d, v9.2d, v14.2d\n" + "zip2 v14.2d, v9.2d, v14.2d\n" + "zip1 v9.2d, v10.2d, v15.2d\n" + "zip2 v15.2d, v10.2d, v15.2d\n" + "zip1 v10.2d, v11.2d, v16.2d\n" + "zip2 v16.2d, v11.2d, v16.2d\n" + "zip1 v11.2d, v12.2d, v17.2d\n" + "zip2 v17.2d, v12.2d, v17.2d\n" + "zip1 v12.2d, v13.2d, v18.2d\n" + "zip2 v18.2d, v13.2d, v18.2d\n" + "zip1 v13.2d, v20.2d, v19.2d\n" + "zip2 v19.2d, v20.2d, v19.2d\n" + "b 62f\n" + "61:" // 
Height 2: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "62:" // Height 2: setup done + "mov x26, #0x0\n" + "63:" // Height 2: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w25, [x20, x26, LSL #0x2]\n" + "tbz %x[flags], #3, 64f\n" + "ldr x20, [%x[input_ptr], x26, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x24, [x20, #0x0]\n" + "ldr x23, [x20, #0x8]\n" + "cbnz x26, 65f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x24, x24, x20, LSL #2\n" + "add x23, x23, x20, LSL #2\n" + "b 65f\n" + "64:" // Height 2: setup direct input + "mov x24, %x[input_ptr]\n" + "add x23, x24, x21, LSL #2\n" + "65:" // Height 2: input setup done + "cmp x25, #0x4\n" + "blt 68f\n" + "ld1 { v0.4s }, [x24], #0x10\n" + "ld1 { v1.4s }, [x23], #0x10\n" + "cmp x25, #0x8\n" + "ldr q4, [x28, #0x0]\n" + "ldr q5, [x28, #0x10]\n" + "ldr q6, [x28, #0x20]\n" + "ldr q7, [x28, #0x30]\n" + "blt 67f\n" + "66:" // Height 2: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "sub x25, x25, #0x4\n" + "prfm pldl1keep, [x24, #0x80]\n" + "prfm pldl1keep, [x23, #0x80]\n" + "cmp x25, #0x8\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + "ld1 { v1.4s }, [x23], #0x10\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + "ldr q4, [x28, #0x40]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + "ldr q5, [x28, #0x50]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x60]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0x70]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + "ldr q4, [x28, #0x80]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + "ldr q5, [x28, #0x90]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0xa0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0xb0]\n" + "add x28, x28, #0xc0\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + "ldr q4, [x28, #0x0]\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + "ldr q5, [x28, #0x10]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x20]\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + "ld1 { v0.4s }, [x24], #0x10\n" + "ldr q7, [x28, #0x30]\n" + "bge 66b\n" + "67:" // Height 2: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "sub x25, x25, #0x4\n" + "prfm pldl1keep, [x24, #0x80]\n" + "prfm pldl1keep, [x23, #0x80]\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + "ldr q4, [x28, #0x40]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + "ldr q5, [x28, #0x50]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x60]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0x70]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + "ldr q4, [x28, #0x80]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + "ldr q5, [x28, #0x90]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0xa0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0xb0]\n" + "add x28, x28, #0xc0\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, 
v0.8h, v4.8h\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + "68:" // Height 2: Multiply loop: Main loop skip + "cbz x25, 71f\n" + "cbz x25, 71f\n" + "tbz x25, #1, 69f\n" + "ldr d0, [x24], #0x8\n" + "ldr d1, [x23], #0x8\n" + "tbz x25, #0, 70f\n" + "ld1 { v0.s }[2], [x24]\n" + "ld1 { v1.s }[2], [x23]\n" + "b 70f\n" + "69:" // Height 2: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x24, #0x0]\n" + "ldr s1, [x23, #0x0]\n" + "70:" // Height 2: Multiply loop: Ragged operand read: Done + "ldr q4, [x28, #0x0]\n" + "ldr q5, [x28, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "ldr q6, [x28, #0x20]\n" + "ldr q7, [x28, #0x30]\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + "ldr q4, [x28, #0x40]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + "ldr q5, [x28, #0x50]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x60]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0x70]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + "ldr q4, [x28, #0x80]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + "ldr q5, [x28, #0x90]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0xa0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0xb0]\n" + "add x28, x28, #0xc0\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + "71:" // Height 2: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x26, x26, #0x1\n" + "cmp x26, x20\n" + "bne 63b\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "uzp1 v4.2d, v8.2d, v14.2d\n" + "uzp2 v8.2d, v8.2d, v14.2d\n" + "prfm pstl1keep, [x27, #0x0]\n" + "uzp1 v14.2d, v9.2d, v15.2d\n" + "uzp2 v9.2d, v9.2d, v15.2d\n" + "uzp1 v15.2d, v10.2d, v16.2d\n" + "uzp2 v10.2d, v10.2d, v16.2d\n" + "add x24, x27, x20, LSL #2\n" + "uzp1 v16.2d, v11.2d, v17.2d\n" + "uzp2 v11.2d, v11.2d, v17.2d\n" + "prfm pstl1keep, [x24, #0x0]\n" + "uzp1 v17.2d, v12.2d, v18.2d\n" + "uzp2 v12.2d, v12.2d, v18.2d\n" + "uzp1 v18.2d, v13.2d, v19.2d\n" + "uzp2 v13.2d, v13.2d, v19.2d\n" + "tbz %x[flags], #1, 72f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v22.4s }, [x21]\n" + "ld1r { v21.4s }, [x20]\n" + "fmin v4.4s, v4.4s, v22.4s\n" + "fmin v14.4s, v14.4s, v22.4s\n" + "fmin v15.4s, v15.4s, v22.4s\n" + "fmin v16.4s, v16.4s, v22.4s\n" + "fmin v17.4s, v17.4s, v22.4s\n" + "fmin v18.4s, v18.4s, v22.4s\n" + "fmin v8.4s, v8.4s, v22.4s\n" + "fmin v9.4s, v9.4s, v22.4s\n" + "fmin v10.4s, v10.4s, v22.4s\n" + "fmin v11.4s, v11.4s, v22.4s\n" + "fmin v12.4s, v12.4s, v22.4s\n" + "fmin v13.4s, v13.4s, v22.4s\n" + "fmax v4.4s, v4.4s, v21.4s\n" + "fmax v14.4s, v14.4s, v21.4s\n" + "fmax v15.4s, v15.4s, v21.4s\n" + "fmax v16.4s, v16.4s, v21.4s\n" + "fmax v17.4s, v17.4s, v21.4s\n" + "fmax v18.4s, v18.4s, v21.4s\n" + "fmax v8.4s, v8.4s, v21.4s\n" + "fmax v9.4s, v9.4s, v21.4s\n" + "fmax v10.4s, v10.4s, v21.4s\n" + "fmax v11.4s, v11.4s, v21.4s\n" + "fmax v12.4s, v12.4s, v21.4s\n" + "fmax v13.4s, v13.4s, v21.4s\n" + "72:" // Height 2: No activation + "cmp x9, #0x18\n" + "bge 85f\n" + "tbz x9, #4, 76f\n" + "st1 { v4.4s }, [x27], #0x10\n" + "st1 { v14.4s }, 
[x27], #0x10\n" + "st1 { v15.4s }, [x27], #0x10\n" + "st1 { v16.4s }, [x27], #0x10\n" + "st1 { v8.4s }, [x24], #0x10\n" + "st1 { v9.4s }, [x24], #0x10\n" + "st1 { v10.4s }, [x24], #0x10\n" + "st1 { v11.4s }, [x24], #0x10\n" + "tbz x9, #2, 74f\n" + "st1 { v17.4s }, [x27], #0x10\n" + "st1 { v12.4s }, [x24], #0x10\n" + "tbz x9, #1, 73f\n" + "str d18, [x27], #0x8\n" + "str d13, [x24], #0x8\n" + "tbz x9, #0, 84f\n" + "st1 { v18.s }[2], [x27]\n" + "st1 { v13.s }[2], [x24]\n" + "b 84f\n" + "73:" // Height 2: Partial direct writeback: partial_1_20 + "tbz x9, #0, 84f\n" + "str s18, [x27, #0x0]\n" + "str s13, [x24, #0x0]\n" + "b 84f\n" + "74:" // Height 2: Partial direct writeback: partial_2_16 + "tbz x9, #1, 75f\n" + "str d17, [x27], #0x8\n" + "str d12, [x24], #0x8\n" + "tbz x9, #0, 84f\n" + "st1 { v17.s }[2], [x27]\n" + "st1 { v12.s }[2], [x24]\n" + "b 84f\n" + "75:" // Height 2: Partial direct writeback: partial_1_16 + "tbz x9, #0, 84f\n" + "str s17, [x27, #0x0]\n" + "str s12, [x24, #0x0]\n" + "b 84f\n" + "76:" // Height 2: Partial direct writeback: partial_8_0 + "tbz x9, #3, 80f\n" + "st1 { v4.4s }, [x27], #0x10\n" + "st1 { v14.4s }, [x27], #0x10\n" + "st1 { v8.4s }, [x24], #0x10\n" + "st1 { v9.4s }, [x24], #0x10\n" + "tbz x9, #2, 78f\n" + "st1 { v15.4s }, [x27], #0x10\n" + "st1 { v10.4s }, [x24], #0x10\n" + "tbz x9, #1, 77f\n" + "str d16, [x27], #0x8\n" + "str d11, [x24], #0x8\n" + "tbz x9, #0, 84f\n" + "st1 { v16.s }[2], [x27]\n" + "st1 { v11.s }[2], [x24]\n" + "b 84f\n" + "77:" // Height 2: Partial direct writeback: partial_1_12 + "tbz x9, #0, 84f\n" + "str s16, [x27, #0x0]\n" + "str s11, [x24, #0x0]\n" + "b 84f\n" + "78:" // Height 2: Partial direct writeback: partial_2_8 + "tbz x9, #1, 79f\n" + "str d15, [x27], #0x8\n" + "str d10, [x24], #0x8\n" + "tbz x9, #0, 84f\n" + "st1 { v15.s }[2], [x27]\n" + "st1 { v10.s }[2], [x24]\n" + "b 84f\n" + "79:" // Height 2: Partial direct writeback: partial_1_8 + "tbz x9, #0, 84f\n" + "str s15, [x27, #0x0]\n" + "str s10, [x24, #0x0]\n" + "b 84f\n" + "80:" // Height 2: Partial direct writeback: partial_4_0 + "tbz x9, #2, 82f\n" + "st1 { v4.4s }, [x27], #0x10\n" + "st1 { v8.4s }, [x24], #0x10\n" + "tbz x9, #1, 81f\n" + "str d14, [x27], #0x8\n" + "str d9, [x24], #0x8\n" + "tbz x9, #0, 84f\n" + "st1 { v14.s }[2], [x27]\n" + "st1 { v9.s }[2], [x24]\n" + "b 84f\n" + "81:" // Height 2: Partial direct writeback: partial_1_4 + "tbz x9, #0, 84f\n" + "str s14, [x27, #0x0]\n" + "str s9, [x24, #0x0]\n" + "b 84f\n" + "82:" // Height 2: Partial direct writeback: partial_2_0 + "tbz x9, #1, 83f\n" + "str d4, [x27], #0x8\n" + "str d8, [x24], #0x8\n" + "tbz x9, #0, 84f\n" + "st1 { v4.s }[2], [x27]\n" + "st1 { v8.s }[2], [x24]\n" + "b 84f\n" + "83:" // Height 2: Partial direct writeback: partial_1_0 + "str s4, [x27, #0x0]\n" + "str s8, [x24, #0x0]\n" + "84:" // Height 2: Partial direct writeback: Done + "b 86f\n" + "85:" // Height 2: Full writeback + "str q4, [x27, #0x0]\n" + "str q14, [x27, #0x10]\n" + "str q15, [x27, #0x20]\n" + "str q16, [x27, #0x30]\n" + "str q17, [x27, #0x40]\n" + "str q18, [x27, #0x50]\n" + "add x27, x27, #0x60\n" + "str q8, [x24, #0x0]\n" + "str q9, [x24, #0x10]\n" + "str q10, [x24, #0x20]\n" + "str q11, [x24, #0x30]\n" + "str q12, [x24, #0x40]\n" + "str q13, [x24, #0x50]\n" + "86:" // Height 2: Writeback done + "subs x9, x9, #0x18\n" + "bgt 45b\n" + "b 174f\n" + "87:" // Height 3 + "ldr x9, [%x[args_ptr], %[offsetof_N]]\n" + "ldr x28, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x27, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "88:" // Height 3: Column 
loop + "cbz x28, 89f\n" + "ldr q8, [x28, #0x0]\n" + "ldr q9, [x28, #0x10]\n" + "ldr q10, [x28, #0x20]\n" + "ldr q11, [x28, #0x30]\n" + "ldr q12, [x28, #0x40]\n" + "ldr q13, [x28, #0x50]\n" + "add x28, x28, #0x60\n" + "zip2 v14.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "zip2 v15.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v16.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v17.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "zip2 v18.2d, v12.2d, v12.2d\n" + "zip1 v12.2d, v12.2d, v12.2d\n" + "zip2 v19.2d, v13.2d, v13.2d\n" + "zip1 v13.2d, v13.2d, v13.2d\n" + "mov v20.16b, v8.16b\n" + "mov v26.16b, v14.16b\n" + "mov v21.16b, v9.16b\n" + "mov v27.16b, v15.16b\n" + "mov v22.16b, v10.16b\n" + "mov v28.16b, v16.16b\n" + "mov v23.16b, v11.16b\n" + "mov v29.16b, v17.16b\n" + "mov v24.16b, v12.16b\n" + "mov v30.16b, v18.16b\n" + "mov v25.16b, v13.16b\n" + "mov v31.16b, v19.16b\n" + "b 105f\n" + "89:" // Height 3: no bias + "tbz %x[flags], #0, 104f\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x9, #0x18\n" + "add x24, x27, x20, LSL #2\n" + "add x23, x24, x20, LSL #2\n" + "bge 102f\n" + "tbz x9, #4, 93f\n" + "ld1 { v9.4s }, [x27], #0x10\n" + "ld1 { v14.4s }, [x24], #0x10\n" + "ld1 { v21.4s }, [x23], #0x10\n" + "ld1 { v10.4s }, [x27], #0x10\n" + "ld1 { v15.4s }, [x24], #0x10\n" + "ld1 { v22.4s }, [x23], #0x10\n" + "ld1 { v11.4s }, [x27], #0x10\n" + "ld1 { v16.4s }, [x24], #0x10\n" + "ld1 { v23.4s }, [x23], #0x10\n" + "ld1 { v12.4s }, [x27], #0x10\n" + "ld1 { v17.4s }, [x24], #0x10\n" + "ld1 { v24.4s }, [x23], #0x10\n" + "tbz x9, #2, 91f\n" + "ld1 { v13.4s }, [x27], #0x10\n" + "ld1 { v18.4s }, [x24], #0x10\n" + "ld1 { v25.4s }, [x23], #0x10\n" + "tbz x9, #1, 90f\n" + "ldr d20, [x27], #0x8\n" + "ldr d19, [x24], #0x8\n" + "mov x20, #0x58\n" + "ldr d4, [x23], #0x8\n" + "tbz x9, #0, 101f\n" + "ld1 { v20.s }[2], [x27]\n" + "ld1 { v19.s }[2], [x24]\n" + "ld1 { v4.s }[2], [x23]\n" + "b 101f\n" + "90:" // Height 3: Partial accumulate: partial_1_20 + "mov x20, #0x50\n" + "tbz x9, #0, 101f\n" + "ldr s20, [x27, #0x0]\n" + "ldr s19, [x24, #0x0]\n" + "ldr s4, [x23, #0x0]\n" + "b 101f\n" + "91:" // Height 3: Partial accumulate: partial_2_16 + "tbz x9, #1, 92f\n" + "ldr d13, [x27], #0x8\n" + "ldr d18, [x24], #0x8\n" + "mov x20, #0x48\n" + "ldr d25, [x23], #0x8\n" + "tbz x9, #0, 101f\n" + "ld1 { v13.s }[2], [x27]\n" + "ld1 { v18.s }[2], [x24]\n" + "ld1 { v25.s }[2], [x23]\n" + "b 101f\n" + "92:" // Height 3: Partial accumulate: partial_1_16 + "mov x20, #0x40\n" + "tbz x9, #0, 101f\n" + "ldr s13, [x27, #0x0]\n" + "ldr s18, [x24, #0x0]\n" + "ldr s25, [x23, #0x0]\n" + "b 101f\n" + "93:" // Height 3: Partial accumulate: partial_8_0 + "tbz x9, #3, 97f\n" + "ld1 { v9.4s }, [x27], #0x10\n" + "ld1 { v14.4s }, [x24], #0x10\n" + "ld1 { v21.4s }, [x23], #0x10\n" + "ld1 { v10.4s }, [x27], #0x10\n" + "ld1 { v15.4s }, [x24], #0x10\n" + "ld1 { v22.4s }, [x23], #0x10\n" + "tbz x9, #2, 95f\n" + "ld1 { v11.4s }, [x27], #0x10\n" + "ld1 { v16.4s }, [x24], #0x10\n" + "ld1 { v23.4s }, [x23], #0x10\n" + "tbz x9, #1, 94f\n" + "ldr d12, [x27], #0x8\n" + "ldr d17, [x24], #0x8\n" + "mov x20, #0x38\n" + "ldr d24, [x23], #0x8\n" + "tbz x9, #0, 101f\n" + "ld1 { v12.s }[2], [x27]\n" + "ld1 { v17.s }[2], [x24]\n" + "ld1 { v24.s }[2], [x23]\n" + "b 101f\n" + "94:" // Height 3: Partial accumulate: partial_1_12 + "mov x20, #0x30\n" + "tbz x9, #0, 101f\n" + "ldr s12, [x27, #0x0]\n" + "ldr s17, [x24, #0x0]\n" + "ldr s24, [x23, #0x0]\n" + "b 101f\n" + "95:" // Height 3: Partial 
accumulate: partial_2_8 + "tbz x9, #1, 96f\n" + "ldr d11, [x27], #0x8\n" + "ldr d16, [x24], #0x8\n" + "mov x20, #0x28\n" + "ldr d23, [x23], #0x8\n" + "tbz x9, #0, 101f\n" + "ld1 { v11.s }[2], [x27]\n" + "ld1 { v16.s }[2], [x24]\n" + "ld1 { v23.s }[2], [x23]\n" + "b 101f\n" + "96:" // Height 3: Partial accumulate: partial_1_8 + "mov x20, #0x20\n" + "tbz x9, #0, 101f\n" + "ldr s11, [x27, #0x0]\n" + "ldr s16, [x24, #0x0]\n" + "ldr s23, [x23, #0x0]\n" + "b 101f\n" + "97:" // Height 3: Partial accumulate: partial_4_0 + "tbz x9, #2, 99f\n" + "ld1 { v9.4s }, [x27], #0x10\n" + "ld1 { v14.4s }, [x24], #0x10\n" + "ld1 { v21.4s }, [x23], #0x10\n" + "tbz x9, #1, 98f\n" + "ldr d10, [x27], #0x8\n" + "ldr d15, [x24], #0x8\n" + "mov x20, #0x18\n" + "ldr d22, [x23], #0x8\n" + "tbz x9, #0, 101f\n" + "ld1 { v10.s }[2], [x27]\n" + "ld1 { v15.s }[2], [x24]\n" + "ld1 { v22.s }[2], [x23]\n" + "b 101f\n" + "98:" // Height 3: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x9, #0, 101f\n" + "ldr s10, [x27, #0x0]\n" + "ldr s15, [x24, #0x0]\n" + "ldr s22, [x23, #0x0]\n" + "b 101f\n" + "99:" // Height 3: Partial accumulate: partial_2_0 + "tbz x9, #1, 100f\n" + "ldr d9, [x27], #0x8\n" + "ldr d14, [x24], #0x8\n" + "mov x20, #0x8\n" + "ldr d21, [x23], #0x8\n" + "tbz x9, #0, 101f\n" + "ld1 { v9.s }[2], [x27]\n" + "ld1 { v14.s }[2], [x24]\n" + "ld1 { v21.s }[2], [x23]\n" + "b 101f\n" + "100:" // Height 3: Partial accumulate: partial_1_0 + "ldr s9, [x27, #0x0]\n" + "ldr s14, [x24, #0x0]\n" + "mov x20, #0x0\n" + "ldr s21, [x23, #0x0]\n" + "101:" // Height 3: Partial accumulate: Done + "sub x27, x27, x20\n" + "b 103f\n" + "102:" // Height 3: full accumulate + "ldr q9, [x27, #0x0]\n" + "ldr q10, [x27, #0x10]\n" + "ldr q11, [x27, #0x20]\n" + "ldr q12, [x27, #0x30]\n" + "ldr q13, [x27, #0x40]\n" + "ldr q20, [x27, #0x50]\n" + "ldr q14, [x24, #0x0]\n" + "ldr q15, [x24, #0x10]\n" + "ldr q16, [x24, #0x20]\n" + "ldr q17, [x24, #0x30]\n" + "ldr q18, [x24, #0x40]\n" + "ldr q19, [x24, #0x50]\n" + "ldr q21, [x23, #0x0]\n" + "ldr q22, [x23, #0x10]\n" + "ldr q23, [x23, #0x20]\n" + "ldr q24, [x23, #0x30]\n" + "ldr q25, [x23, #0x40]\n" + "ldr q4, [x23, #0x50]\n" + "103:" // Height 3: MMLA fixup + "zip1 v8.2d, v9.2d, v14.2d\n" + "zip2 v14.2d, v9.2d, v14.2d\n" + "zip1 v9.2d, v10.2d, v15.2d\n" + "zip2 v15.2d, v10.2d, v15.2d\n" + "zip1 v10.2d, v11.2d, v16.2d\n" + "zip2 v16.2d, v11.2d, v16.2d\n" + "zip1 v11.2d, v12.2d, v17.2d\n" + "zip2 v17.2d, v12.2d, v17.2d\n" + "zip1 v12.2d, v13.2d, v18.2d\n" + "zip2 v18.2d, v13.2d, v18.2d\n" + "zip1 v13.2d, v20.2d, v19.2d\n" + "zip2 v19.2d, v20.2d, v19.2d\n" + "zip1 v20.2d, v21.2d, v26.2d\n" + "zip2 v26.2d, v21.2d, v26.2d\n" + "zip1 v21.2d, v22.2d, v27.2d\n" + "zip2 v27.2d, v22.2d, v27.2d\n" + "zip1 v22.2d, v23.2d, v28.2d\n" + "zip2 v28.2d, v23.2d, v28.2d\n" + "zip1 v23.2d, v24.2d, v29.2d\n" + "zip2 v29.2d, v24.2d, v29.2d\n" + "zip1 v24.2d, v25.2d, v30.2d\n" + "zip2 v30.2d, v25.2d, v30.2d\n" + "zip1 v25.2d, v4.2d, v31.2d\n" + "zip2 v31.2d, v4.2d, v31.2d\n" + "b 105f\n" + "104:" // Height 3: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + 
"movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "105:" // Height 3: setup done + "mov x26, #0x0\n" + "106:" // Height 3: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w25, [x20, x26, LSL #0x2]\n" + "tbz %x[flags], #3, 107f\n" + "ldr x20, [%x[input_ptr], x26, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x24, [x20, #0x0]\n" + "ldr x23, [x20, #0x8]\n" + "ldr x22, [x20, #0x10]\n" + "cbnz x26, 108f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x24, x24, x20, LSL #2\n" + "add x23, x23, x20, LSL #2\n" + "add x22, x22, x20, LSL #2\n" + "b 108f\n" + "107:" // Height 3: setup direct input + "mov x24, %x[input_ptr]\n" + "add x23, x24, x21, LSL #2\n" + "add x22, x23, x21, LSL #2\n" + "108:" // Height 3: input setup done + "cmp x25, #0x4\n" + "blt 111f\n" + "ld1 { v0.4s }, [x24], #0x10\n" + "ld1 { v1.4s }, [x23], #0x10\n" + "cmp x25, #0x8\n" + "ld1 { v2.4s }, [x22], #0x10\n" + "ldr q4, [x28, #0x0]\n" + "ldr q5, [x28, #0x10]\n" + "ldr q6, [x28, #0x20]\n" + "ldr q7, [x28, #0x30]\n" + "blt 110f\n" + "109:" // Height 3: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x25, x25, #0x4\n" + "prfm pldl1keep, [x24, #0x80]\n" + "cmp x25, #0x8\n" + "prfm pldl1keep, [x23, #0x80]\n" + "prfm pldl1keep, [x22, #0x80]\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + "ld1 { v1.4s }, [x23], #0x10\n" + ".inst 0x6e44ec54 // bfmmla v20.4s, v2.8h, v4.8h\n" + ".inst 0x6e45ec5a // bfmmla v26.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec5b // bfmmla v27.4s, v2.8h, v7.8h\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + "ldr q4, [x28, #0x40]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + "ldr q5, [x28, #0x50]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x60]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0x70]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec56 // bfmmla v22.4s, v2.8h, v4.8h\n" + "ldr q4, [x28, #0x80]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5c // bfmmla v28.4s, v2.8h, v5.8h\n" + "ldr q5, [x28, #0x90]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0xa0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5d // bfmmla v29.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0xb0]\n" + "add x28, x28, #0xc0\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" + "ldr q4, [x28, #0x0]\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5e // bfmmla v30.4s, v2.8h, v5.8h\n" + "ldr q5, [x28, #0x10]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x20]\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + "ld1 { v0.4s }, [x24], #0x10\n" + ".inst 0x6e47ec5f // bfmmla v31.4s, v2.8h, v7.8h\n" + "ld1 { v2.4s }, [x22], #0x10\n" + "ldr q7, [x28, #0x30]\n" + "bge 109b\n" + "110:" // Height 3: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x25, x25, #0x4\n" + "prfm pldl1keep, [x24, #0x80]\n" + "prfm pldl1keep, [x23, #0x80]\n" + "prfm pldl1keep, [x22, #0x80]\n" + 
".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x6e44ec54 // bfmmla v20.4s, v2.8h, v4.8h\n" + ".inst 0x6e45ec5a // bfmmla v26.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec5b // bfmmla v27.4s, v2.8h, v7.8h\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + "ldr q4, [x28, #0x40]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + "ldr q5, [x28, #0x50]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x60]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0x70]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec56 // bfmmla v22.4s, v2.8h, v4.8h\n" + "ldr q4, [x28, #0x80]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5c // bfmmla v28.4s, v2.8h, v5.8h\n" + "ldr q5, [x28, #0x90]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0xa0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5d // bfmmla v29.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0xb0]\n" + "add x28, x28, #0xc0\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5e // bfmmla v30.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5f // bfmmla v31.4s, v2.8h, v7.8h\n" + "111:" // Height 3: Multiply loop: Main loop skip + "cbz x25, 114f\n" + "cbz x25, 114f\n" + "tbz x25, #1, 112f\n" + "ldr d0, [x24], #0x8\n" + "ldr d1, [x23], #0x8\n" + "ldr d2, [x22], #0x8\n" + "tbz x25, #0, 113f\n" + "ld1 { v0.s }[2], [x24]\n" + "ld1 { v1.s }[2], [x23]\n" + "ld1 { v2.s }[2], [x22]\n" + "b 113f\n" + "112:" // Height 3: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x24, #0x0]\n" + "ldr s1, [x23, #0x0]\n" + "ldr s2, [x22, #0x0]\n" + "113:" // Height 3: Multiply loop: Ragged operand read: Done + "ldr q4, [x28, #0x0]\n" + "ldr q5, [x28, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "ldr q6, [x28, #0x20]\n" + "ldr q7, [x28, #0x30]\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x6e44ec54 // bfmmla v20.4s, v2.8h, v4.8h\n" + ".inst 0x6e45ec5a // bfmmla v26.4s, v2.8h, v5.8h\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + "ldr q4, [x28, #0x40]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + "ldr q5, [x28, #0x50]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x60]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5b // bfmmla v27.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0x70]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec56 // bfmmla v22.4s, v2.8h, v4.8h\n" + "ldr q4, [x28, #0x80]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5c // bfmmla v28.4s, v2.8h, v5.8h\n" + "ldr q5, [x28, #0x90]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0xa0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5d // bfmmla v29.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0xb0]\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + "add x28, x28, #0xc0\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, 
v2.8h, v4.8h\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5e // bfmmla v30.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5f // bfmmla v31.4s, v2.8h, v7.8h\n" + "114:" // Height 3: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x26, x26, #0x1\n" + "cmp x26, x20\n" + "bne 106b\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "uzp1 v4.2d, v8.2d, v14.2d\n" + "uzp2 v8.2d, v8.2d, v14.2d\n" + "prfm pstl1keep, [x27, #0x0]\n" + "uzp1 v14.2d, v9.2d, v15.2d\n" + "uzp2 v9.2d, v9.2d, v15.2d\n" + "uzp1 v15.2d, v10.2d, v16.2d\n" + "uzp2 v10.2d, v10.2d, v16.2d\n" + "add x24, x27, x20, LSL #2\n" + "add x23, x24, x20, LSL #2\n" + "uzp1 v16.2d, v11.2d, v17.2d\n" + "uzp2 v11.2d, v11.2d, v17.2d\n" + "prfm pstl1keep, [x24, #0x0]\n" + "prfm pstl1keep, [x23, #0x0]\n" + "uzp1 v17.2d, v12.2d, v18.2d\n" + "uzp2 v12.2d, v12.2d, v18.2d\n" + "uzp1 v18.2d, v13.2d, v19.2d\n" + "uzp2 v13.2d, v13.2d, v19.2d\n" + "uzp1 v20.2d, v20.2d, v26.2d\n" + "uzp1 v21.2d, v21.2d, v27.2d\n" + "uzp1 v22.2d, v22.2d, v28.2d\n" + "uzp1 v23.2d, v23.2d, v29.2d\n" + "uzp1 v24.2d, v24.2d, v30.2d\n" + "uzp1 v25.2d, v25.2d, v31.2d\n" + "tbz %x[flags], #1, 115f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v1.4s }, [x21]\n" + "ld1r { v0.4s }, [x20]\n" + "fmin v4.4s, v4.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v20.4s, v20.4s, v1.4s\n" + "fmin v21.4s, v21.4s, v1.4s\n" + "fmin v22.4s, v22.4s, v1.4s\n" + "fmin v23.4s, v23.4s, v1.4s\n" + "fmin v24.4s, v24.4s, v1.4s\n" + "fmin v25.4s, v25.4s, v1.4s\n" + "fmax v4.4s, v4.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "fmax v23.4s, v23.4s, v0.4s\n" + "fmax v24.4s, v24.4s, v0.4s\n" + "fmax v25.4s, v25.4s, v0.4s\n" + "115:" // Height 3: No activation + "cmp x9, #0x18\n" + "bge 128f\n" + "tbz x9, #4, 119f\n" + "st1 { v4.4s }, [x27], #0x10\n" + "st1 { v14.4s }, [x27], #0x10\n" + "st1 { v15.4s }, [x27], #0x10\n" + "st1 { v16.4s }, [x27], #0x10\n" + "st1 { v8.4s }, [x24], #0x10\n" + "st1 { v9.4s }, [x24], #0x10\n" + "st1 { v10.4s }, [x24], #0x10\n" + "st1 { v11.4s }, [x24], #0x10\n" + "st1 { v20.4s }, [x23], #0x10\n" + "st1 { v21.4s }, [x23], #0x10\n" + "st1 { v22.4s }, [x23], #0x10\n" + "st1 { v23.4s }, [x23], #0x10\n" + "tbz x9, #2, 117f\n" + "st1 { v17.4s }, [x27], #0x10\n" + "st1 { v12.4s }, [x24], #0x10\n" + "st1 { v24.4s }, [x23], #0x10\n" + "tbz x9, #1, 116f\n" + "str d18, [x27], #0x8\n" + "str d13, [x24], #0x8\n" + "str d25, [x23], #0x8\n" + "tbz x9, #0, 127f\n" + "st1 { v18.s }[2], [x27]\n" + "st1 { v13.s }[2], [x24]\n" + "st1 { v25.s }[2], [x23]\n" + "b 127f\n" 
+ "116:" // Height 3: Partial direct writeback: partial_1_20 + "tbz x9, #0, 127f\n" + "str s18, [x27, #0x0]\n" + "str s13, [x24, #0x0]\n" + "str s25, [x23, #0x0]\n" + "b 127f\n" + "117:" // Height 3: Partial direct writeback: partial_2_16 + "tbz x9, #1, 118f\n" + "str d17, [x27], #0x8\n" + "str d12, [x24], #0x8\n" + "str d24, [x23], #0x8\n" + "tbz x9, #0, 127f\n" + "st1 { v17.s }[2], [x27]\n" + "st1 { v12.s }[2], [x24]\n" + "st1 { v24.s }[2], [x23]\n" + "b 127f\n" + "118:" // Height 3: Partial direct writeback: partial_1_16 + "tbz x9, #0, 127f\n" + "str s17, [x27, #0x0]\n" + "str s12, [x24, #0x0]\n" + "str s24, [x23, #0x0]\n" + "b 127f\n" + "119:" // Height 3: Partial direct writeback: partial_8_0 + "tbz x9, #3, 123f\n" + "st1 { v4.4s }, [x27], #0x10\n" + "st1 { v14.4s }, [x27], #0x10\n" + "st1 { v8.4s }, [x24], #0x10\n" + "st1 { v9.4s }, [x24], #0x10\n" + "st1 { v20.4s }, [x23], #0x10\n" + "st1 { v21.4s }, [x23], #0x10\n" + "tbz x9, #2, 121f\n" + "st1 { v15.4s }, [x27], #0x10\n" + "st1 { v10.4s }, [x24], #0x10\n" + "st1 { v22.4s }, [x23], #0x10\n" + "tbz x9, #1, 120f\n" + "str d16, [x27], #0x8\n" + "str d11, [x24], #0x8\n" + "str d23, [x23], #0x8\n" + "tbz x9, #0, 127f\n" + "st1 { v16.s }[2], [x27]\n" + "st1 { v11.s }[2], [x24]\n" + "st1 { v23.s }[2], [x23]\n" + "b 127f\n" + "120:" // Height 3: Partial direct writeback: partial_1_12 + "tbz x9, #0, 127f\n" + "str s16, [x27, #0x0]\n" + "str s11, [x24, #0x0]\n" + "str s23, [x23, #0x0]\n" + "b 127f\n" + "121:" // Height 3: Partial direct writeback: partial_2_8 + "tbz x9, #1, 122f\n" + "str d15, [x27], #0x8\n" + "str d10, [x24], #0x8\n" + "str d22, [x23], #0x8\n" + "tbz x9, #0, 127f\n" + "st1 { v15.s }[2], [x27]\n" + "st1 { v10.s }[2], [x24]\n" + "st1 { v22.s }[2], [x23]\n" + "b 127f\n" + "122:" // Height 3: Partial direct writeback: partial_1_8 + "tbz x9, #0, 127f\n" + "str s15, [x27, #0x0]\n" + "str s10, [x24, #0x0]\n" + "str s22, [x23, #0x0]\n" + "b 127f\n" + "123:" // Height 3: Partial direct writeback: partial_4_0 + "tbz x9, #2, 125f\n" + "st1 { v4.4s }, [x27], #0x10\n" + "st1 { v8.4s }, [x24], #0x10\n" + "st1 { v20.4s }, [x23], #0x10\n" + "tbz x9, #1, 124f\n" + "str d14, [x27], #0x8\n" + "str d9, [x24], #0x8\n" + "str d21, [x23], #0x8\n" + "tbz x9, #0, 127f\n" + "st1 { v14.s }[2], [x27]\n" + "st1 { v9.s }[2], [x24]\n" + "st1 { v21.s }[2], [x23]\n" + "b 127f\n" + "124:" // Height 3: Partial direct writeback: partial_1_4 + "tbz x9, #0, 127f\n" + "str s14, [x27, #0x0]\n" + "str s9, [x24, #0x0]\n" + "str s21, [x23, #0x0]\n" + "b 127f\n" + "125:" // Height 3: Partial direct writeback: partial_2_0 + "tbz x9, #1, 126f\n" + "str d4, [x27], #0x8\n" + "str d8, [x24], #0x8\n" + "str d20, [x23], #0x8\n" + "tbz x9, #0, 127f\n" + "st1 { v4.s }[2], [x27]\n" + "st1 { v8.s }[2], [x24]\n" + "st1 { v20.s }[2], [x23]\n" + "b 127f\n" + "126:" // Height 3: Partial direct writeback: partial_1_0 + "str s4, [x27, #0x0]\n" + "str s8, [x24, #0x0]\n" + "str s20, [x23, #0x0]\n" + "127:" // Height 3: Partial direct writeback: Done + "b 129f\n" + "128:" // Height 3: Full writeback + "str q4, [x27, #0x0]\n" + "str q14, [x27, #0x10]\n" + "str q15, [x27, #0x20]\n" + "str q16, [x27, #0x30]\n" + "str q17, [x27, #0x40]\n" + "str q18, [x27, #0x50]\n" + "add x27, x27, #0x60\n" + "str q8, [x24, #0x0]\n" + "str q9, [x24, #0x10]\n" + "str q10, [x24, #0x20]\n" + "str q11, [x24, #0x30]\n" + "str q12, [x24, #0x40]\n" + "str q13, [x24, #0x50]\n" + "str q20, [x23, #0x0]\n" + "str q21, [x23, #0x10]\n" + "str q22, [x23, #0x20]\n" + "str q23, [x23, #0x30]\n" + "str q24, [x23, #0x40]\n" 
+ "str q25, [x23, #0x50]\n" + "129:" // Height 3: Writeback done + "subs x9, x9, #0x18\n" + "bgt 88b\n" + "b 174f\n" + "130:" // Height 4 + "ldr x21, [%x[args_ptr], %[offsetof_output_offset]]\n" + "ldr x27, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "mov x20, #0x10\n" + "ldr x9, [%x[args_ptr], %[offsetof_N]]\n" + "ldr x28, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "madd x20, x21, x20, x27\n" + "str x20, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "131:" // Height 4: Column loop + "cbz x28, 132f\n" + "ldr q8, [x28, #0x0]\n" + "ldr q9, [x28, #0x10]\n" + "ldr q10, [x28, #0x20]\n" + "ldr q11, [x28, #0x30]\n" + "ldr q12, [x28, #0x40]\n" + "ldr q13, [x28, #0x50]\n" + "add x28, x28, #0x60\n" + "zip2 v14.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "zip2 v15.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v16.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v17.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "zip2 v18.2d, v12.2d, v12.2d\n" + "zip1 v12.2d, v12.2d, v12.2d\n" + "zip2 v19.2d, v13.2d, v13.2d\n" + "zip1 v13.2d, v13.2d, v13.2d\n" + "mov v20.16b, v8.16b\n" + "mov v26.16b, v14.16b\n" + "mov v21.16b, v9.16b\n" + "mov v27.16b, v15.16b\n" + "mov v22.16b, v10.16b\n" + "mov v28.16b, v16.16b\n" + "mov v23.16b, v11.16b\n" + "mov v29.16b, v17.16b\n" + "mov v24.16b, v12.16b\n" + "mov v30.16b, v18.16b\n" + "mov v25.16b, v13.16b\n" + "mov v31.16b, v19.16b\n" + "b 148f\n" + "132:" // Height 4: no bias + "tbz %x[flags], #0, 147f\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x9, #0x18\n" + "add x24, x27, x20, LSL #2\n" + "add x23, x24, x20, LSL #2\n" + "add x22, x23, x20, LSL #2\n" + "bge 145f\n" + "tbz x9, #4, 136f\n" + "ld1 { v9.4s }, [x27], #0x10\n" + "ld1 { v14.4s }, [x24], #0x10\n" + "ld1 { v21.4s }, [x23], #0x10\n" + "ld1 { v26.4s }, [x22], #0x10\n" + "ld1 { v10.4s }, [x27], #0x10\n" + "ld1 { v15.4s }, [x24], #0x10\n" + "ld1 { v22.4s }, [x23], #0x10\n" + "ld1 { v27.4s }, [x22], #0x10\n" + "ld1 { v11.4s }, [x27], #0x10\n" + "ld1 { v16.4s }, [x24], #0x10\n" + "ld1 { v23.4s }, [x23], #0x10\n" + "ld1 { v28.4s }, [x22], #0x10\n" + "ld1 { v12.4s }, [x27], #0x10\n" + "ld1 { v17.4s }, [x24], #0x10\n" + "ld1 { v24.4s }, [x23], #0x10\n" + "ld1 { v29.4s }, [x22], #0x10\n" + "tbz x9, #2, 134f\n" + "ld1 { v13.4s }, [x27], #0x10\n" + "ld1 { v18.4s }, [x24], #0x10\n" + "ld1 { v25.4s }, [x23], #0x10\n" + "ld1 { v30.4s }, [x22], #0x10\n" + "tbz x9, #1, 133f\n" + "ldr d20, [x27], #0x8\n" + "ldr d19, [x24], #0x8\n" + "mov x20, #0x58\n" + "ldr d4, [x23], #0x8\n" + "ldr d31, [x22], #0x8\n" + "tbz x9, #0, 144f\n" + "ld1 { v20.s }[2], [x27]\n" + "ld1 { v19.s }[2], [x24]\n" + "ld1 { v4.s }[2], [x23]\n" + "ld1 { v31.s }[2], [x22]\n" + "b 144f\n" + "133:" // Height 4: Partial accumulate: partial_1_20 + "mov x20, #0x50\n" + "tbz x9, #0, 144f\n" + "ldr s20, [x27, #0x0]\n" + "ldr s19, [x24, #0x0]\n" + "ldr s4, [x23, #0x0]\n" + "ldr s31, [x22, #0x0]\n" + "b 144f\n" + "134:" // Height 4: Partial accumulate: partial_2_16 + "tbz x9, #1, 135f\n" + "ldr d13, [x27], #0x8\n" + "ldr d18, [x24], #0x8\n" + "mov x20, #0x48\n" + "ldr d25, [x23], #0x8\n" + "ldr d30, [x22], #0x8\n" + "tbz x9, #0, 144f\n" + "ld1 { v13.s }[2], [x27]\n" + "ld1 { v18.s }[2], [x24]\n" + "ld1 { v25.s }[2], [x23]\n" + "ld1 { v30.s }[2], [x22]\n" + "b 144f\n" + "135:" // Height 4: Partial accumulate: partial_1_16 + "mov x20, #0x40\n" + "tbz x9, #0, 144f\n" + "ldr s13, [x27, #0x0]\n" + "ldr s18, [x24, #0x0]\n" + "ldr s25, [x23, #0x0]\n" + "ldr s30, [x22, #0x0]\n" + "b 144f\n" + "136:" // Height 4: 
Partial accumulate: partial_8_0 + "tbz x9, #3, 140f\n" + "ld1 { v9.4s }, [x27], #0x10\n" + "ld1 { v14.4s }, [x24], #0x10\n" + "ld1 { v21.4s }, [x23], #0x10\n" + "ld1 { v26.4s }, [x22], #0x10\n" + "ld1 { v10.4s }, [x27], #0x10\n" + "ld1 { v15.4s }, [x24], #0x10\n" + "ld1 { v22.4s }, [x23], #0x10\n" + "ld1 { v27.4s }, [x22], #0x10\n" + "tbz x9, #2, 138f\n" + "ld1 { v11.4s }, [x27], #0x10\n" + "ld1 { v16.4s }, [x24], #0x10\n" + "ld1 { v23.4s }, [x23], #0x10\n" + "ld1 { v28.4s }, [x22], #0x10\n" + "tbz x9, #1, 137f\n" + "ldr d12, [x27], #0x8\n" + "ldr d17, [x24], #0x8\n" + "mov x20, #0x38\n" + "ldr d24, [x23], #0x8\n" + "ldr d29, [x22], #0x8\n" + "tbz x9, #0, 144f\n" + "ld1 { v12.s }[2], [x27]\n" + "ld1 { v17.s }[2], [x24]\n" + "ld1 { v24.s }[2], [x23]\n" + "ld1 { v29.s }[2], [x22]\n" + "b 144f\n" + "137:" // Height 4: Partial accumulate: partial_1_12 + "mov x20, #0x30\n" + "tbz x9, #0, 144f\n" + "ldr s12, [x27, #0x0]\n" + "ldr s17, [x24, #0x0]\n" + "ldr s24, [x23, #0x0]\n" + "ldr s29, [x22, #0x0]\n" + "b 144f\n" + "138:" // Height 4: Partial accumulate: partial_2_8 + "tbz x9, #1, 139f\n" + "ldr d11, [x27], #0x8\n" + "ldr d16, [x24], #0x8\n" + "mov x20, #0x28\n" + "ldr d23, [x23], #0x8\n" + "ldr d28, [x22], #0x8\n" + "tbz x9, #0, 144f\n" + "ld1 { v11.s }[2], [x27]\n" + "ld1 { v16.s }[2], [x24]\n" + "ld1 { v23.s }[2], [x23]\n" + "ld1 { v28.s }[2], [x22]\n" + "b 144f\n" + "139:" // Height 4: Partial accumulate: partial_1_8 + "mov x20, #0x20\n" + "tbz x9, #0, 144f\n" + "ldr s11, [x27, #0x0]\n" + "ldr s16, [x24, #0x0]\n" + "ldr s23, [x23, #0x0]\n" + "ldr s28, [x22, #0x0]\n" + "b 144f\n" + "140:" // Height 4: Partial accumulate: partial_4_0 + "tbz x9, #2, 142f\n" + "ld1 { v9.4s }, [x27], #0x10\n" + "ld1 { v14.4s }, [x24], #0x10\n" + "ld1 { v21.4s }, [x23], #0x10\n" + "ld1 { v26.4s }, [x22], #0x10\n" + "tbz x9, #1, 141f\n" + "ldr d10, [x27], #0x8\n" + "ldr d15, [x24], #0x8\n" + "mov x20, #0x18\n" + "ldr d22, [x23], #0x8\n" + "ldr d27, [x22], #0x8\n" + "tbz x9, #0, 144f\n" + "ld1 { v10.s }[2], [x27]\n" + "ld1 { v15.s }[2], [x24]\n" + "ld1 { v22.s }[2], [x23]\n" + "ld1 { v27.s }[2], [x22]\n" + "b 144f\n" + "141:" // Height 4: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x9, #0, 144f\n" + "ldr s10, [x27, #0x0]\n" + "ldr s15, [x24, #0x0]\n" + "ldr s22, [x23, #0x0]\n" + "ldr s27, [x22, #0x0]\n" + "b 144f\n" + "142:" // Height 4: Partial accumulate: partial_2_0 + "tbz x9, #1, 143f\n" + "ldr d9, [x27], #0x8\n" + "ldr d14, [x24], #0x8\n" + "mov x20, #0x8\n" + "ldr d21, [x23], #0x8\n" + "ldr d26, [x22], #0x8\n" + "tbz x9, #0, 144f\n" + "ld1 { v9.s }[2], [x27]\n" + "ld1 { v14.s }[2], [x24]\n" + "ld1 { v21.s }[2], [x23]\n" + "ld1 { v26.s }[2], [x22]\n" + "b 144f\n" + "143:" // Height 4: Partial accumulate: partial_1_0 + "ldr s9, [x27, #0x0]\n" + "ldr s14, [x24, #0x0]\n" + "mov x20, #0x0\n" + "ldr s21, [x23, #0x0]\n" + "ldr s26, [x22, #0x0]\n" + "144:" // Height 4: Partial accumulate: Done + "sub x27, x27, x20\n" + "b 146f\n" + "145:" // Height 4: full accumulate + "ldr q9, [x27, #0x0]\n" + "ldr q10, [x27, #0x10]\n" + "ldr q11, [x27, #0x20]\n" + "ldr q12, [x27, #0x30]\n" + "ldr q13, [x27, #0x40]\n" + "ldr q20, [x27, #0x50]\n" + "ldr q14, [x24, #0x0]\n" + "ldr q15, [x24, #0x10]\n" + "ldr q16, [x24, #0x20]\n" + "ldr q17, [x24, #0x30]\n" + "ldr q18, [x24, #0x40]\n" + "ldr q19, [x24, #0x50]\n" + "ldr q21, [x23, #0x0]\n" + "ldr q22, [x23, #0x10]\n" + "ldr q23, [x23, #0x20]\n" + "ldr q24, [x23, #0x30]\n" + "ldr q25, [x23, #0x40]\n" + "ldr q4, [x23, #0x50]\n" + "ldr q26, [x22, #0x0]\n" + "ldr q27, [x22, 
#0x10]\n" + "ldr q28, [x22, #0x20]\n" + "ldr q29, [x22, #0x30]\n" + "ldr q30, [x22, #0x40]\n" + "ldr q31, [x22, #0x50]\n" + "146:" // Height 4: MMLA fixup + "zip1 v8.2d, v9.2d, v14.2d\n" + "zip2 v14.2d, v9.2d, v14.2d\n" + "zip1 v9.2d, v10.2d, v15.2d\n" + "zip2 v15.2d, v10.2d, v15.2d\n" + "zip1 v10.2d, v11.2d, v16.2d\n" + "zip2 v16.2d, v11.2d, v16.2d\n" + "zip1 v11.2d, v12.2d, v17.2d\n" + "zip2 v17.2d, v12.2d, v17.2d\n" + "zip1 v12.2d, v13.2d, v18.2d\n" + "zip2 v18.2d, v13.2d, v18.2d\n" + "zip1 v13.2d, v20.2d, v19.2d\n" + "zip2 v19.2d, v20.2d, v19.2d\n" + "zip1 v20.2d, v21.2d, v26.2d\n" + "zip2 v26.2d, v21.2d, v26.2d\n" + "zip1 v21.2d, v22.2d, v27.2d\n" + "zip2 v27.2d, v22.2d, v27.2d\n" + "zip1 v22.2d, v23.2d, v28.2d\n" + "zip2 v28.2d, v23.2d, v28.2d\n" + "zip1 v23.2d, v24.2d, v29.2d\n" + "zip2 v29.2d, v24.2d, v29.2d\n" + "zip1 v24.2d, v25.2d, v30.2d\n" + "zip2 v30.2d, v25.2d, v30.2d\n" + "zip1 v25.2d, v4.2d, v31.2d\n" + "zip2 v31.2d, v4.2d, v31.2d\n" + "b 148f\n" + "147:" // Height 4: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "148:" // Height 4: setup done + "mov x26, #0x0\n" + "149:" // Height 4: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w25, [x20, x26, LSL #0x2]\n" + "tbz %x[flags], #3, 150f\n" + "ldr x20, [%x[input_ptr], x26, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x24, [x20, #0x0]\n" + "ldr x23, [x20, #0x8]\n" + "ldr x22, [x20, #0x10]\n" + "ldr x21, [x20, #0x18]\n" + "cbnz x26, 151f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x24, x24, x20, LSL #2\n" + "add x23, x23, x20, LSL #2\n" + "add x22, x22, x20, LSL #2\n" + "add x21, x21, x20, LSL #2\n" + "b 151f\n" + "150:" // Height 4: setup direct input + "mov x24, %x[input_ptr]\n" + "add x23, x24, x21, LSL #2\n" + "add x22, x23, x21, LSL #2\n" + "add x21, x22, x21, LSL #2\n" + "151:" // Height 4: input setup done + "cmp x25, #0x4\n" + "blt 154f\n" + "ld1 { v0.4s }, [x24], #0x10\n" + "ld1 { v2.4s }, [x22], #0x10\n" + "cmp x25, #0x8\n" + "ld1 { v1.4s }, [x23], #0x10\n" + "ld1 { v3.4s }, [x21], #0x10\n" + "ldr q4, [x28, #0x0]\n" + "ldr q5, [x28, #0x10]\n" + "ldr q6, [x28, #0x20]\n" + "ldr q7, [x28, #0x30]\n" + "blt 153f\n" + "152:" // Height 4: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x25, x25, #0x4\n" + "prfm pldl1keep, [x24, #0x80]\n" + "cmp x25, #0x8\n" + "prfm pldl1keep, [x23, #0x80]\n" + "prfm pldl1keep, [x22, #0x80]\n" + "prfm pldl1keep, [x21, #0x80]\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + "ld1 { v1.4s }, [x23], #0x10\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + "ld1 { v3.4s }, [x21], #0x10\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + ".inst 0x6e44ec54 // bfmmla v20.4s, 
v2.8h, v4.8h\n" + "ldr q4, [x28, #0x40]\n" + ".inst 0x6e45ec5a // bfmmla v26.4s, v2.8h, v5.8h\n" + "ldr q5, [x28, #0x50]\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x60]\n" + ".inst 0x6e47ec5b // bfmmla v27.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0x70]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec56 // bfmmla v22.4s, v2.8h, v4.8h\n" + "ldr q4, [x28, #0x80]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5c // bfmmla v28.4s, v2.8h, v5.8h\n" + "ldr q5, [x28, #0x90]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0xa0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5d // bfmmla v29.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0xb0]\n" + "add x28, x28, #0xc0\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" + "ldr q4, [x28, #0x0]\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5e // bfmmla v30.4s, v2.8h, v5.8h\n" + "ldr q5, [x28, #0x10]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x20]\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + "ld1 { v0.4s }, [x24], #0x10\n" + ".inst 0x6e47ec5f // bfmmla v31.4s, v2.8h, v7.8h\n" + "ld1 { v2.4s }, [x22], #0x10\n" + "ldr q7, [x28, #0x30]\n" + "bge 152b\n" + "153:" // Height 4: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x25, x25, #0x4\n" + "prfm pldl1keep, [x24, #0x80]\n" + "prfm pldl1keep, [x23, #0x80]\n" + "prfm pldl1keep, [x22, #0x80]\n" + "prfm pldl1keep, [x21, #0x80]\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + ".inst 0x6e44ec54 // bfmmla v20.4s, v2.8h, v4.8h\n" + "ldr q4, [x28, #0x40]\n" + ".inst 0x6e45ec5a // bfmmla v26.4s, v2.8h, v5.8h\n" + "ldr q5, [x28, #0x50]\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x60]\n" + ".inst 0x6e47ec5b // bfmmla v27.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0x70]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec56 // bfmmla v22.4s, v2.8h, v4.8h\n" + "ldr q4, [x28, #0x80]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5c // bfmmla v28.4s, v2.8h, v5.8h\n" + "ldr q5, [x28, #0x90]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0xa0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5d // bfmmla v29.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0xb0]\n" + "add x28, x28, #0xc0\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5e // bfmmla v30.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5f // bfmmla v31.4s, v2.8h, v7.8h\n" + "154:" // Height 4: Multiply loop: Main loop skip + "cbz x25, 157f\n" + "tbz x25, #1, 155f\n" + "ldr d0, [x24], 
#0x8\n" + "ldr d1, [x23], #0x8\n" + "ldr d2, [x22], #0x8\n" + "ldr d3, [x21], #0x8\n" + "tbz x25, #0, 156f\n" + "ld1 { v0.s }[2], [x24]\n" + "ld1 { v1.s }[2], [x23]\n" + "ld1 { v2.s }[2], [x22]\n" + "ld1 { v3.s }[2], [x21]\n" + "b 156f\n" + "155:" // Height 4: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x24, #0x0]\n" + "ldr s1, [x23, #0x0]\n" + "ldr s2, [x22, #0x0]\n" + "ldr s3, [x21, #0x0]\n" + "156:" // Height 4: Multiply loop: Ragged operand read: Done + "ldr q4, [x28, #0x0]\n" + "ldr q5, [x28, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "ldr q6, [x28, #0x20]\n" + "ldr q7, [x28, #0x30]\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec54 // bfmmla v20.4s, v2.8h, v4.8h\n" + "ldr q4, [x28, #0x40]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5a // bfmmla v26.4s, v2.8h, v5.8h\n" + "ldr q5, [x28, #0x50]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x60]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5b // bfmmla v27.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0x70]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec56 // bfmmla v22.4s, v2.8h, v4.8h\n" + "ldr q4, [x28, #0x80]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5c // bfmmla v28.4s, v2.8h, v5.8h\n" + "ldr q5, [x28, #0x90]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0xa0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5d // bfmmla v29.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0xb0]\n" + "add x28, x28, #0xc0\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5e // bfmmla v30.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5f // bfmmla v31.4s, v2.8h, v7.8h\n" + "157:" // Height 4: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x26, x26, #0x1\n" + "cmp x26, x20\n" + "bne 149b\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "uzp1 v4.2d, v8.2d, v14.2d\n" + "uzp2 v8.2d, v8.2d, v14.2d\n" + "prfm pstl1keep, [x27, #0x0]\n" + "uzp1 v14.2d, v9.2d, v15.2d\n" + "uzp2 v9.2d, v9.2d, v15.2d\n" + "uzp1 v15.2d, v10.2d, v16.2d\n" + "uzp2 v10.2d, v10.2d, v16.2d\n" + "add x24, x27, x20, LSL #2\n" + "add x23, x24, x20, LSL #2\n" + "add x22, x23, x20, LSL #2\n" + "uzp1 v16.2d, v11.2d, v17.2d\n" + "uzp2 v11.2d, v11.2d, v17.2d\n" + "prfm pstl1keep, [x24, #0x0]\n" + "uzp1 v17.2d, v12.2d, v18.2d\n" + "uzp2 v12.2d, v12.2d, v18.2d\n" + "prfm pstl1keep, [x23, #0x0]\n" + "prfm pstl1keep, [x22, #0x0]\n" + "uzp1 v18.2d, v13.2d, v19.2d\n" + "uzp2 v13.2d, v13.2d, v19.2d\n" + "uzp1 v19.2d, v20.2d, v26.2d\n" + "uzp2 v20.2d, v20.2d, v26.2d\n" + "uzp1 v26.2d, v21.2d, v27.2d\n" + "uzp2 v21.2d, v21.2d, v27.2d\n" + "uzp1 v27.2d, v22.2d, v28.2d\n" + "uzp2 v22.2d, v22.2d, v28.2d\n" + "uzp1 v28.2d, v23.2d, v29.2d\n" + "uzp2 v23.2d, v23.2d, v29.2d\n" + "uzp1 v29.2d, v24.2d, v30.2d\n" + "uzp2 v24.2d, v24.2d, v30.2d\n" + "uzp1 v30.2d, v25.2d, v31.2d\n" + "uzp2 v25.2d, 
v25.2d, v31.2d\n" + "tbz %x[flags], #1, 158f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v1.4s }, [x21]\n" + "ld1r { v0.4s }, [x20]\n" + "fmin v4.4s, v4.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v19.4s, v19.4s, v1.4s\n" + "fmin v26.4s, v26.4s, v1.4s\n" + "fmin v27.4s, v27.4s, v1.4s\n" + "fmin v28.4s, v28.4s, v1.4s\n" + "fmin v29.4s, v29.4s, v1.4s\n" + "fmin v30.4s, v30.4s, v1.4s\n" + "fmin v20.4s, v20.4s, v1.4s\n" + "fmin v21.4s, v21.4s, v1.4s\n" + "fmin v22.4s, v22.4s, v1.4s\n" + "fmin v23.4s, v23.4s, v1.4s\n" + "fmin v24.4s, v24.4s, v1.4s\n" + "fmin v25.4s, v25.4s, v1.4s\n" + "fmax v4.4s, v4.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v26.4s, v26.4s, v0.4s\n" + "fmax v27.4s, v27.4s, v0.4s\n" + "fmax v28.4s, v28.4s, v0.4s\n" + "fmax v29.4s, v29.4s, v0.4s\n" + "fmax v30.4s, v30.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "fmax v23.4s, v23.4s, v0.4s\n" + "fmax v24.4s, v24.4s, v0.4s\n" + "fmax v25.4s, v25.4s, v0.4s\n" + "158:" // Height 4: No activation + "cmp x9, #0x18\n" + "bge 171f\n" + "tbz x9, #4, 162f\n" + "st1 { v4.4s }, [x27], #0x10\n" + "st1 { v14.4s }, [x27], #0x10\n" + "st1 { v15.4s }, [x27], #0x10\n" + "st1 { v16.4s }, [x27], #0x10\n" + "st1 { v8.4s }, [x24], #0x10\n" + "st1 { v9.4s }, [x24], #0x10\n" + "st1 { v10.4s }, [x24], #0x10\n" + "st1 { v11.4s }, [x24], #0x10\n" + "st1 { v19.4s }, [x23], #0x10\n" + "st1 { v26.4s }, [x23], #0x10\n" + "st1 { v27.4s }, [x23], #0x10\n" + "st1 { v28.4s }, [x23], #0x10\n" + "st1 { v20.4s }, [x22], #0x10\n" + "st1 { v21.4s }, [x22], #0x10\n" + "st1 { v22.4s }, [x22], #0x10\n" + "st1 { v23.4s }, [x22], #0x10\n" + "tbz x9, #2, 160f\n" + "st1 { v17.4s }, [x27], #0x10\n" + "st1 { v12.4s }, [x24], #0x10\n" + "st1 { v29.4s }, [x23], #0x10\n" + "st1 { v24.4s }, [x22], #0x10\n" + "tbz x9, #1, 159f\n" + "str d18, [x27], #0x8\n" + "str d13, [x24], #0x8\n" + "str d30, [x23], #0x8\n" + "str d25, [x22], #0x8\n" + "tbz x9, #0, 170f\n" + "st1 { v18.s }[2], [x27]\n" + "st1 { v13.s }[2], [x24]\n" + "st1 { v30.s }[2], [x23]\n" + "st1 { v25.s }[2], [x22]\n" + "b 170f\n" + "159:" // Height 4: Partial direct writeback: partial_1_20 + "tbz x9, #0, 170f\n" + "str s18, [x27, #0x0]\n" + "str s13, [x24, #0x0]\n" + "str s30, [x23, #0x0]\n" + "str s25, [x22, #0x0]\n" + "b 170f\n" + "160:" // Height 4: Partial direct writeback: partial_2_16 + "tbz x9, #1, 161f\n" + "str d17, [x27], #0x8\n" + "str d12, [x24], #0x8\n" + "str d29, [x23], #0x8\n" + "str d24, [x22], #0x8\n" + "tbz x9, #0, 170f\n" + "st1 { v17.s }[2], [x27]\n" + "st1 { v12.s }[2], [x24]\n" + "st1 { v29.s }[2], [x23]\n" + "st1 { v24.s }[2], [x22]\n" + "b 170f\n" + "161:" // Height 4: Partial direct writeback: partial_1_16 + "tbz x9, #0, 170f\n" + "str s17, [x27, #0x0]\n" + "str s12, [x24, #0x0]\n" 
+ "str s29, [x23, #0x0]\n" + "str s24, [x22, #0x0]\n" + "b 170f\n" + "162:" // Height 4: Partial direct writeback: partial_8_0 + "tbz x9, #3, 166f\n" + "st1 { v4.4s }, [x27], #0x10\n" + "st1 { v14.4s }, [x27], #0x10\n" + "st1 { v8.4s }, [x24], #0x10\n" + "st1 { v9.4s }, [x24], #0x10\n" + "st1 { v19.4s }, [x23], #0x10\n" + "st1 { v26.4s }, [x23], #0x10\n" + "st1 { v20.4s }, [x22], #0x10\n" + "st1 { v21.4s }, [x22], #0x10\n" + "tbz x9, #2, 164f\n" + "st1 { v15.4s }, [x27], #0x10\n" + "st1 { v10.4s }, [x24], #0x10\n" + "st1 { v27.4s }, [x23], #0x10\n" + "st1 { v22.4s }, [x22], #0x10\n" + "tbz x9, #1, 163f\n" + "str d16, [x27], #0x8\n" + "str d11, [x24], #0x8\n" + "str d28, [x23], #0x8\n" + "str d23, [x22], #0x8\n" + "tbz x9, #0, 170f\n" + "st1 { v16.s }[2], [x27]\n" + "st1 { v11.s }[2], [x24]\n" + "st1 { v28.s }[2], [x23]\n" + "st1 { v23.s }[2], [x22]\n" + "b 170f\n" + "163:" // Height 4: Partial direct writeback: partial_1_12 + "tbz x9, #0, 170f\n" + "str s16, [x27, #0x0]\n" + "str s11, [x24, #0x0]\n" + "str s28, [x23, #0x0]\n" + "str s23, [x22, #0x0]\n" + "b 170f\n" + "164:" // Height 4: Partial direct writeback: partial_2_8 + "tbz x9, #1, 165f\n" + "str d15, [x27], #0x8\n" + "str d10, [x24], #0x8\n" + "str d27, [x23], #0x8\n" + "str d22, [x22], #0x8\n" + "tbz x9, #0, 170f\n" + "st1 { v15.s }[2], [x27]\n" + "st1 { v10.s }[2], [x24]\n" + "st1 { v27.s }[2], [x23]\n" + "st1 { v22.s }[2], [x22]\n" + "b 170f\n" + "165:" // Height 4: Partial direct writeback: partial_1_8 + "tbz x9, #0, 170f\n" + "str s15, [x27, #0x0]\n" + "str s10, [x24, #0x0]\n" + "str s27, [x23, #0x0]\n" + "str s22, [x22, #0x0]\n" + "b 170f\n" + "166:" // Height 4: Partial direct writeback: partial_4_0 + "tbz x9, #2, 168f\n" + "st1 { v4.4s }, [x27], #0x10\n" + "st1 { v8.4s }, [x24], #0x10\n" + "st1 { v19.4s }, [x23], #0x10\n" + "st1 { v20.4s }, [x22], #0x10\n" + "tbz x9, #1, 167f\n" + "str d14, [x27], #0x8\n" + "str d9, [x24], #0x8\n" + "str d26, [x23], #0x8\n" + "str d21, [x22], #0x8\n" + "tbz x9, #0, 170f\n" + "st1 { v14.s }[2], [x27]\n" + "st1 { v9.s }[2], [x24]\n" + "st1 { v26.s }[2], [x23]\n" + "st1 { v21.s }[2], [x22]\n" + "b 170f\n" + "167:" // Height 4: Partial direct writeback: partial_1_4 + "tbz x9, #0, 170f\n" + "str s14, [x27, #0x0]\n" + "str s9, [x24, #0x0]\n" + "str s26, [x23, #0x0]\n" + "str s21, [x22, #0x0]\n" + "b 170f\n" + "168:" // Height 4: Partial direct writeback: partial_2_0 + "tbz x9, #1, 169f\n" + "str d4, [x27], #0x8\n" + "str d8, [x24], #0x8\n" + "str d19, [x23], #0x8\n" + "str d20, [x22], #0x8\n" + "tbz x9, #0, 170f\n" + "st1 { v4.s }[2], [x27]\n" + "st1 { v8.s }[2], [x24]\n" + "st1 { v19.s }[2], [x23]\n" + "st1 { v20.s }[2], [x22]\n" + "b 170f\n" + "169:" // Height 4: Partial direct writeback: partial_1_0 + "str s4, [x27, #0x0]\n" + "str s8, [x24, #0x0]\n" + "str s19, [x23, #0x0]\n" + "str s20, [x22, #0x0]\n" + "170:" // Height 4: Partial direct writeback: Done + "b 172f\n" + "171:" // Height 4: Full writeback + "str q4, [x27, #0x0]\n" + "str q14, [x27, #0x10]\n" + "str q15, [x27, #0x20]\n" + "str q16, [x27, #0x30]\n" + "str q17, [x27, #0x40]\n" + "str q18, [x27, #0x50]\n" + "add x27, x27, #0x60\n" + "str q8, [x24, #0x0]\n" + "str q9, [x24, #0x10]\n" + "str q10, [x24, #0x20]\n" + "str q11, [x24, #0x30]\n" + "str q12, [x24, #0x40]\n" + "str q13, [x24, #0x50]\n" + "str q19, [x23, #0x0]\n" + "str q26, [x23, #0x10]\n" + "str q27, [x23, #0x20]\n" + "str q28, [x23, #0x30]\n" + "str q29, [x23, #0x40]\n" + "str q30, [x23, #0x50]\n" + "str q20, [x22, #0x0]\n" + "str q21, [x22, #0x10]\n" + "str q22, [x22, 
#0x20]\n" + "str q23, [x22, #0x30]\n" + "str q24, [x22, #0x40]\n" + "str q25, [x22, #0x50]\n" + "172:" // Height 4: Writeback done + "subs x9, x9, #0x18\n" + "bgt 131b\n" + "subs %x[m], %x[m], #0x4\n" + "beq 174f\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 173f\n" + "add x21, x21, #0x4\n" + "str x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "b 1b\n" + "173:" // Update direct input + "mov x20, #0x10\n" + "madd %x[input_ptr], x20, x21, %x[input_ptr]\n" + "b 1b\n" + "174:" // Exit + : [input_ptr] "+&r"(input_ptr), [m] "+&r"(m) + : [args_ptr] "r"(&ka), [flags] "r"(flags), [offset_max] "I"(offsetof(KernelArgs, maxval)), + [offset_min] "I"(offsetof(KernelArgs, minval)), [offsetof_B_ptr] "I"(offsetof(KernelArgs, B_ptr)), + [offsetof_N] "I"(offsetof(KernelArgs, N)), + [offsetof_input_initial_col] "I"(offsetof(KernelArgs, input_initial_col)), + [offsetof_input_offset] "I"(offsetof(KernelArgs, input_offset)), + [offsetof_num_strings] "I"(offsetof(KernelArgs, num_strings)), + [offsetof_output_offset] "I"(offsetof(KernelArgs, output_offset)), + [offsetof_output_ptr] "I"(offsetof(KernelArgs, output_ptr)), + [offsetof_string_lengths] "I"(offsetof(KernelArgs, string_lengths)) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_bf16p/kai_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla.h b/kai/ukernels/matmul/matmul_clamp_f32_f32_bf16p/kai_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla.h new file mode 100644 index 0000000000000000000000000000000000000000..033d11c8c329ca934b2fc13a052d50ddaa3bf570 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_bf16p/kai_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla.h @@ -0,0 +1,132 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_BF16. +#else // Architectural features check. + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# kai_rhs_pack_kxn_f16p24x1biasf32_f16_f32_neon to pack the RHS matrix + +/// -------------------------------------------------- + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return The m step value. +size_t kai_get_m_step_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(void); + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(void); + +/// Gets mr value. +/// +/// This is the number of rows of output block size. +/// +/// @return The mr value. +size_t kai_get_mr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(void); + +/// Gets nr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The nr value. +size_t kai_get_nr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(void); + +/// Gets kr value. 
+/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The kr value. +size_t kai_get_kr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(void); + +/// Gets sr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The sr value. +size_t kai_get_sr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(void); + +/// Gets the offset in bytes to the data element in the LHS matrix buffer. +/// +/// @param[in] m_idx Row index. +/// @param[in] stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_offset_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(size_t m_idx, size_t stride); + +/// Gets the offset in bytes to the data element in the packed RHS matrix buffer. +/// +/// @param[in] n_idx Row index. +/// @param[in] k Number of columns. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(size_t n_idx, size_t k); + +/// Gets the offset in bytes to the data element in the destination matrix buffer. +/// +/// @param[in] m_idx Row index. +/// @param[in] n_idx Column index. +/// @param[in] stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_dst_offset_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla( + size_t m_idx, size_t n_idx, size_t stride); + +/// Gets the size in bytes of the destination matrix buffer. +/// +/// @param[in] m Number of rows. +/// @param[in] n Number of columns. +/// +/// @return The size in bytes of the destination matrix buffer. +size_t kai_get_dst_size_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(size_t m, size_t n); + +/// Runs the matrix multiplication micro-kernel followed by a clamp operation. +/// +/// The pointer of each buffer (LHS, packed RHS and output) needs to be advanced by the offset +/// calculated using the following functions: +/// +/// * LHS: @ref kai_get_lhs_offset_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla. +/// +/// @param[in] m Number of output rows to be computed. +/// @param[in] n Number of output columns to be computed. +/// @param[in] k Common dimension of the LHS and RHS operands. +/// @param[in] lhs LHS matrix buffer. +/// @param[in] lhs_stride Row stride in bytes of the LHS matrix. +/// @param[in] rhs_packed Packed RHS buffer. +/// @param[out] dst Output matrix buffer. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float). +/// @param[in] clamp_min Minimum value to clamp the final result. +/// @param[in] clamp_max Maximum value to clamp the final result. +void kai_run_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla( + size_t m, size_t n, size_t k, // + const void* lhs, size_t lhs_stride, // + const void* rhs_packed, // + void* dst, size_t dst_stride_row, size_t dst_stride_col, // + float clamp_min, float clamp_max); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // Architectural features check. 
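Taken together, these entry points imply a two-step calling sequence: pack the RHS (weights plus bias) once with the matching packer, then run the matmul against a plain f32 LHS. Below is a minimal usage sketch, assuming hypothetical row-major f32 buffers `lhs`, `rhs`, `bias` and `dst`, an LHS stride of `k * sizeof(float)`, and clamping effectively disabled by passing the widest float range:

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

#include "kai_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla.h"
#include "kai_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon.h"

// Hypothetical wrapper: dst = clamp(lhs * rhs + bias). All matrices row-major f32.
void matmul_f32_bf16p(size_t m, size_t n, size_t k,
                      const float* lhs, const float* rhs, const float* bias, float* dst) {
    const size_t nr = kai_get_nr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla();  // 24
    const size_t kr = kai_get_kr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla();  // 4
    const size_t sr = kai_get_sr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla();  // 1

    // Pack the RHS once; the packed buffer carries both the bf16 weights and the f32 bias.
    std::vector<uint8_t> rhs_packed(
        kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon(n, k));
    kai_run_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon(
        1, n, k, nr, kr, sr, n * sizeof(float), rhs, bias, nullptr, rhs_packed.data(), 0, nullptr);

    // The LHS stays in plain f32; the micro-kernel converts it to bf16 on the fly.
    kai_run_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla(
        m, n, k, lhs, k * sizeof(float), rhs_packed.data(), dst, n * sizeof(float), sizeof(float),
        std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
}
```

For tiled execution over sub-blocks, the same calls would be combined with the offset getters above instead of passing the base pointers directly.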
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32_bf16p/matmul_clamp_f32_f32_bf16p_interface.h b/kai/ukernels/matmul/matmul_clamp_f32_f32_bf16p/matmul_clamp_f32_f32_bf16p_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..c642c994168b63f8319c55e4460b57bfb7b06e4f --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32_bf16p/matmul_clamp_f32_f32_bf16p_interface.h @@ -0,0 +1,58 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_BF16. +#else // Architectural features check. + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif + +// All micro-kernel variants of the same type share the same interface. +// In this case, the micro-kernel type is: matmul_clamp_f32_f32_bf16p + +/// Micro-kernel helper functions ("get" methods) +typedef size_t (*kai_matmul_clamp_f32_f32_bf16p_get_m_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_f32_bf16p_get_n_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_f32_bf16p_get_mr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_f32_bf16p_get_nr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_f32_bf16p_get_kr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_f32_bf16p_get_sr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_f32_bf16p_get_lhs_offset_func_t)(size_t m_idx, size_t lhs_stride); +typedef size_t (*kai_matmul_clamp_f32_f32_bf16p_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); +typedef size_t (*kai_matmul_clamp_f32_f32_bf16p_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride); +typedef size_t (*kai_matmul_clamp_f32_f32_bf16p_get_dst_size_func_t)(size_t m, size_t n); + +/// Micro-kernel core function ("run" method) +typedef void (*kai_matmul_clamp_f32_f32_bf16p_run_matmul_func_t)( + size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* rhs_packed, void* dst, + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max); + +/// Micro-kernel interface +struct kai_matmul_clamp_f32_f32_bf16p_ukernel { + kai_matmul_clamp_f32_f32_bf16p_get_m_step_func_t get_m_step; + kai_matmul_clamp_f32_f32_bf16p_get_n_step_func_t get_n_step; + kai_matmul_clamp_f32_f32_bf16p_get_mr_func_t get_mr; + kai_matmul_clamp_f32_f32_bf16p_get_nr_func_t get_nr; + kai_matmul_clamp_f32_f32_bf16p_get_kr_func_t get_kr; + kai_matmul_clamp_f32_f32_bf16p_get_sr_func_t get_sr; + kai_matmul_clamp_f32_f32_bf16p_get_lhs_offset_func_t get_lhs_offset; + kai_matmul_clamp_f32_f32_bf16p_get_rhs_packed_offset_func_t get_rhs_packed_offset; + kai_matmul_clamp_f32_f32_bf16p_get_dst_offset_func_t get_dst_offset; + kai_matmul_clamp_f32_f32_bf16p_get_dst_size_func_t get_dst_size; + kai_matmul_clamp_f32_f32_bf16p_run_matmul_func_t run_matmul; +}; + +#ifdef __cplusplus +} +#endif + +#endif // Architectural features check. 
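A variant is wired into this interface by listing its entry points in field order, so dispatch code can select a micro-kernel at runtime without naming it again. A minimal sketch, assuming the 4x24x4 kernel declared above is the variant being registered:

```cpp
#include "kai_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla.h"
#include "matmul_clamp_f32_f32_bf16p_interface.h"

// Field order follows the struct definition: the "get" methods first, then the "run" method.
static const kai_matmul_clamp_f32_f32_bf16p_ukernel ukernel_4x24x4{
    kai_get_m_step_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
    kai_get_n_step_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
    kai_get_mr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
    kai_get_nr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
    kai_get_kr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
    kai_get_sr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
    kai_get_lhs_offset_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
    kai_get_rhs_packed_offset_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
    kai_get_dst_offset_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
    kai_get_dst_size_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
    kai_run_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
};
```

Callers can then go through `ukernel_4x24x4.get_dst_size(m, n)` and `ukernel_4x24x4.run_matmul(...)` regardless of which variant was selected.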
diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.c b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.c new file mode 100644 index 0000000000000000000000000000000000000000..32c845f73696a4bfe6ed27dce66e96510444a5ca --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.c @@ -0,0 +1,212 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) +#error This file must be compiled for AArch64, FEAT_BF16. +#else // Architectural features check. + +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_mr = 8; +static const size_t kai_kr = 4; +static const size_t kai_sr = 1; + +size_t kai_get_m_step_lhs_pack_f32p8x4_bf16_neon(size_t mr) { + KAI_ASSUME(mr == kai_mr); + KAI_UNUSED(mr); + + return kai_mr; +} + +size_t kai_get_lhs_offset_lhs_pack_f32p8x4_bf16_neon(size_t m_idx, size_t lhs_stride) { + KAI_ASSUME(m_idx % (kai_mr) == 0); + + return m_idx * lhs_stride; +} + +size_t kai_get_lhs_packed_offset_lhs_pack_f32p8x4_bf16_neon(size_t m_idx, size_t k) { + KAI_ASSUME(m_idx % kai_mr == 0); + + return m_idx * kai_roundup(k, kai_kr) * sizeof(bfloat16_t); +} + +size_t kai_get_lhs_packed_size_lhs_pack_f32p8x4_bf16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { + KAI_ASSUME(mr == kai_mr); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + + KAI_UNUSED(mr); + KAI_UNUSED(kr); + KAI_UNUSED(sr); + + return kai_roundup(m, kai_mr) * kai_roundup(k, kai_kr) * sizeof(bfloat16_t); +} + +void kai_run_lhs_pack_f32p8x4_bf16_neon( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, + void* lhs_packed) { + KAI_ASSUME(mr == kai_mr); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == kai_sr); + KAI_ASSUME(lhs != NULL); + KAI_ASSUME(lhs_packed != NULL); + + KAI_ASSUME(m_idx_start == 0); + + const size_t block_height = kai_mr; + const size_t row_offset = 0; + + const void* in[block_height]; + + for (size_t block_y = 0; block_y < m; block_y += block_height) { + const size_t height = KAI_MIN(m - block_y, block_height); + void* out = (char*)lhs_packed + block_y * kai_roundup(k, kr) * sizeof(bfloat16_t); + size_t width = k; + + for (size_t y = 0; y < height; y++) { + in[y] = (char*)lhs + (block_y + y) * lhs_stride; + } + + __asm__ __volatile__( + "ldr x28, [%x[in], #0x0]\n" + "ldr x27, [%x[in], #0x8]\n" + "cmp %x[height], #0x8\n" + "ldr x26, [%x[in], #0x10]\n" + "ldr x25, [%x[in], #0x18]\n" + "ldr x24, [%x[in], #0x20]\n" + "ldr x23, [%x[in], #0x28]\n" + "ldr x22, [%x[in], #0x30]\n" + "ldr x21, [%x[in], #0x38]\n" + "add x28, x28, %x[row_offset], LSL #2\n" + "add x27, x27, %x[row_offset], LSL #2\n" + "add x26, x26, %x[row_offset], LSL #2\n" + "add x25, x25, %x[row_offset], LSL #2\n" + "add x24, x24, %x[row_offset], LSL #2\n" + "add x23, x23, %x[row_offset], LSL #2\n" + "add x22, x22, %x[row_offset], LSL #2\n" + "add x21, x21, %x[row_offset], LSL #2\n" + "beq 1f\n" + "cmp %x[height], #0x2\n" + "mov x21, x28\n" + "csel x27, x27, x28, GE\n" + "csel x26, x26, x28, GT\n" + "cmp %x[height], #0x4\n" + "csel x25, x25, x28, GE\n" + "csel x24, x24, x28, GT\n" + "cmp %x[height], #0x6\n" + "csel x23, x23, x28, GE\n" + "csel x22, x22, x28, GT\n" + "1:" // no_pointer_adj + "cmp %x[width], #0x4\n" + "prfm pldl1keep, [x28, #0x0]\n" + "prfm pldl1keep, [x27, #0x0]\n" + "prfm pldl1keep, [x26, #0x0]\n" + "prfm pldl1keep, [x25, #0x0]\n" + "prfm 
pldl1keep, [x24, #0x0]\n" + "prfm pldl1keep, [x23, #0x0]\n" + "prfm pldl1keep, [x22, #0x0]\n" + "prfm pldl1keep, [x21, #0x0]\n" + "prfm pldl1keep, [x28, #0x40]\n" + "prfm pldl1keep, [x27, #0x40]\n" + "prfm pldl1keep, [x26, #0x40]\n" + "prfm pldl1keep, [x25, #0x40]\n" + "prfm pldl1keep, [x24, #0x40]\n" + "prfm pldl1keep, [x23, #0x40]\n" + "prfm pldl1keep, [x22, #0x40]\n" + "prfm pldl1keep, [x21, #0x40]\n" + "blt 3f\n" + "2:" // Main loop head + "ldr q19, [x28], #0x10\n" + "ldr q18, [x26], #0x10\n" + "subs %x[width], %x[width], #0x4\n" + "ldr q17, [x24], #0x10\n" + "ldr q16, [x22], #0x10\n" + "cmp %x[width], #0x4\n" + "ldr q23, [x27], #0x10\n" + "ldr q22, [x25], #0x10\n" + "ldr q21, [x23], #0x10\n" + "ldr q20, [x21], #0x10\n" + ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n" + ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n" + ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "prfm pldl1keep, [x28, #0x70]\n" + "prfm pldl1keep, [x27, #0x70]\n" + "prfm pldl1keep, [x26, #0x70]\n" + "prfm pldl1keep, [x25, #0x70]\n" + "prfm pldl1keep, [x24, #0x70]\n" + "prfm pldl1keep, [x23, #0x70]\n" + ".inst 0x4ea16af3 // bfcvtn2 v19.8h, v23.4s\n" + ".inst 0x4ea16ad2 // bfcvtn2 v18.8h, v22.4s\n" + "prfm pldl1keep, [x22, #0x70]\n" + "prfm pldl1keep, [x21, #0x70]\n" + ".inst 0x4ea16ab1 // bfcvtn2 v17.8h, v21.4s\n" + ".inst 0x4ea16a90 // bfcvtn2 v16.8h, v20.4s\n" + "str q19, [%x[out_ptr], #0x0]\n" + "str q18, [%x[out_ptr], #0x10]\n" + "str q17, [%x[out_ptr], #0x20]\n" + "str q16, [%x[out_ptr], #0x30]\n" + "add %x[out_ptr], %x[out_ptr], #0x40\n" + "bge 2b\n" + "3:" // Main loop skip + "cbz %x[width], 6f\n" + "tbz %x[width], #1, 4f\n" + "ldr d19, [x28], #0x8\n" + "ldr d23, [x27], #0x8\n" + "mov x20, #0x1\n" + "ldr d18, [x26], #0x8\n" + "ldr d22, [x25], #0x8\n" + "ldr d17, [x24], #0x8\n" + "ldr d21, [x23], #0x8\n" + "ldr d16, [x22], #0x8\n" + "ldr d20, [x21], #0x8\n" + "tbz %x[width], #0, 5f\n" + "ld1 { v19.s }[2], [x28]\n" + "ld1 { v23.s }[2], [x27]\n" + "ld1 { v18.s }[2], [x26]\n" + "ld1 { v22.s }[2], [x25]\n" + "ld1 { v17.s }[2], [x24]\n" + "ld1 { v21.s }[2], [x23]\n" + "ld1 { v16.s }[2], [x22]\n" + "ld1 { v20.s }[2], [x21]\n" + "b 5f\n" + "4:" // odd_loads_1_0 + "ldr s19, [x28, #0x0]\n" + "ldr s23, [x27, #0x0]\n" + "mov x20, #0x1\n" + "ldr s18, [x26, #0x0]\n" + "ldr s22, [x25, #0x0]\n" + "ldr s17, [x24, #0x0]\n" + "ldr s21, [x23, #0x0]\n" + "ldr s16, [x22, #0x0]\n" + "ldr s20, [x21, #0x0]\n" + "5:" // Odd load end + ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n" + ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n" + ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + ".inst 0x4ea16af3 // bfcvtn2 v19.8h, v23.4s\n" + ".inst 0x4ea16ad2 // bfcvtn2 v18.8h, v22.4s\n" + ".inst 0x4ea16ab1 // bfcvtn2 v17.8h, v21.4s\n" + ".inst 0x4ea16a90 // bfcvtn2 v16.8h, v20.4s\n" + "str q19, [%x[out_ptr], #0x0]\n" + "str q18, [%x[out_ptr], #0x10]\n" + "str q17, [%x[out_ptr], #0x20]\n" + "str q16, [%x[out_ptr], #0x30]\n" + "add %x[out_ptr], %x[out_ptr], #0x40\n" + "6:" // Odds skip + : [out_ptr] "+&r"(out), [width] "+&r"(width) + : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset) + : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "x20", "x21", "x22", "x23", "x24", + "x25", "x26", "x27", "x28"); + } +} + +#endif // Architectural features check. 
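From the caller's side, this packing routine reduces to sizing the destination buffer and running the pack. A minimal sketch, assuming hypothetical `m`, `k` and a row-major f32 `lhs` with stride `k * sizeof(float)`; mr = 8, kr = 4 and sr = 1 are fixed by the KAI_ASSUME checks above, and `m_idx_start` must be 0:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

#include "kai_lhs_pack_f32p8x4_bf16_neon.h"

// Pack an m x k row-major f32 LHS into the bf16 8x4 block layout consumed by the
// bf16p matmul micro-kernels. Rows are rounded up to mr and columns to kr.
std::vector<uint8_t> pack_lhs(size_t m, size_t k, const float* lhs) {
    std::vector<uint8_t> lhs_packed(
        kai_get_lhs_packed_size_lhs_pack_f32p8x4_bf16_neon(m, k, /*mr=*/8, /*kr=*/4, /*sr=*/1));
    kai_run_lhs_pack_f32p8x4_bf16_neon(
        m, k, /*mr=*/8, /*kr=*/4, /*sr=*/1, /*m_idx_start=*/0,
        lhs, k * sizeof(float), lhs_packed.data());
    return lhs_packed;
}
```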
diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.h b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..dd47ba8818afb7dcb12a3ba8f3ca7c45dc4b601e --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.h @@ -0,0 +1,31 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#include <stddef.h> +#include <stdint.h> + +#include "kai/kai_common.h" + +size_t kai_get_m_step_lhs_pack_f32p8x4_bf16_neon(size_t mr); + +size_t kai_get_lhs_offset_lhs_pack_f32p8x4_bf16_neon(size_t m_idx, size_t lhs_stride); + +size_t kai_get_lhs_packed_offset_lhs_pack_f32p8x4_bf16_neon(size_t m_idx, size_t k); + +size_t kai_get_lhs_packed_size_lhs_pack_f32p8x4_bf16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr); + +void kai_run_lhs_pack_f32p8x4_bf16_neon( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, + void* lhs_packed); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c new file mode 100644 index 0000000000000000000000000000000000000000..943067e573be58311c5ac98a7e089a9eef200b9e --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.c @@ -0,0 +1,461 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) +#error This file must be compiled for AArch64. +#else // Architectural features check. 
+ +#include +#include +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_nr = 12; +static const size_t kai_kr = 4; + +size_t kai_get_n_step_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(void) { + return kai_nr; +} + +size_t kai_get_rhs_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n_idx) { + KAI_ASSUME(n_idx % kai_nr == 0); + + return n_idx * sizeof(float); +} + +size_t kai_get_bias_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n_idx) { + return n_idx * sizeof(uint32_t); +} + +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_nr == 0); + + return n_idx * (sizeof(uint32_t) + kai_roundup(k, kai_kr) * sizeof(bfloat16_t)); +} + +size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n, size_t k) { + return kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(kai_roundup(n, kai_nr), k); +} + +void kai_run_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { + KAI_ASSUME(num_groups == 1); + KAI_ASSUME(nr == kai_nr); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == 1); + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(bias != NULL); + KAI_ASSUME(scale == NULL); + KAI_ASSUME(rhs_packed != NULL); + KAI_ASSUME(extra_bytes == 0); + KAI_ASSUME(params == NULL); + + size_t height = k; + const size_t width = n; + const void* in = rhs; + void* out = rhs_packed; + const size_t in_stride = rhs_stride; + float* pad_row = (float*)alloca(width * sizeof(float)); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(float)); + } + + size_t out_stride = kai_nr * kai_roundup(height, 4) * sizeof(bfloat16_t) + kai_nr * sizeof(uint32_t); + + __asm__ __volatile__( + "mov x22, %x[width]\n" + "mov x21, %x[out]\n" + "cmp x22, #0xc\n" + "blt 2f\n" + "1:" // Bias: Full loop + "ldr q16, [%x[bias], #0x0]\n" + "ldr q26, [%x[bias], #0x10]\n" + "sub x22, x22, #0xc\n" + "ldr q8, [%x[bias], #0x20]\n" + "cmp x22, #0xc\n" + "add %x[bias], %x[bias], #0x30\n" + "str q16, [x21, #0x0]\n" + "str q26, [x21, #0x10]\n" + "str q8, [x21, #0x20]\n" + "add x21, x21, %x[out_stride]\n" + "bge 1b\n" + "cbz x22, 3f\n" + "2:" // Bias: Tail loop + "ldr w20, [%x[bias], #0x0]\n" + "sub x22, x22, #0x1\n" + "add %x[bias], %x[bias], #0x4\n" + "cmp x22, #0x0\n" + "str w20, [x21]\n" + "add x21, x21, #0x4\n" + "bgt 2b\n" + "3:" // Bias: Done + "cmp %x[height], #0x8\n" + "add %x[out], %x[out], #0x30\n" + "blt 12f\n" + "4:" // Main row loop: Head + "mov x9, %x[in]\n" + "mov x28, %x[width]\n" + "mov x27, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "add x26, x9, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "cmp x28, #0xc\n" + "add x23, x24, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x21, x22, %x[in_stride]\n" + "add x20, x21, %x[in_stride]\n" + "add %x[in], x20, %x[in_stride]\n" + "blt 6f\n" + "5:" // Main row loop: Column loop + "ldr q28, [x9], #0x10\n" + "ldr q27, [x26], #0x10\n" + "sub x28, x28, #0xc\n" + "ldr q11, [x25], #0x10\n" + "ldr q5, [x24], #0x10\n" + "cmp x28, #0xc\n" + "ldr q14, [x23], #0x10\n" + "ldr q6, [x22], #0x10\n" + "ldr q2, [x21], #0x10\n" + "ldr q18, [x20], #0x10\n" + "ldr q1, [x9], #0x10\n" + "ldr q7, [x26], #0x10\n" + "zip1 v15.4s, v28.4s, v11.4s\n" + "zip1 v8.4s, v27.4s, v5.4s\n" + "ldr q3, [x25], #0x10\n" + "ldr 
q23, [x24], #0x10\n" + "zip2 v17.4s, v28.4s, v11.4s\n" + "zip2 v27.4s, v27.4s, v5.4s\n" + "ldr q5, [x23], #0x10\n" + "ldr q30, [x22], #0x10\n" + "zip1 v26.4s, v14.4s, v2.4s\n" + "zip1 v31.4s, v6.4s, v18.4s\n" + "ldr q20, [x21], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip2 v12.4s, v14.4s, v2.4s\n" + "zip2 v24.4s, v6.4s, v18.4s\n" + "ldr q29, [x9], #0x10\n" + "ldr q6, [x26], #0x10\n" + "zip1 v18.4s, v1.4s, v3.4s\n" + "zip1 v4.4s, v7.4s, v23.4s\n" + "ldr q22, [x25], #0x10\n" + "ldr q0, [x24], #0x10\n" + "zip2 v3.4s, v1.4s, v3.4s\n" + "zip2 v1.4s, v7.4s, v23.4s\n" + "ldr q2, [x23], #0x10\n" + "ldr q10, [x22], #0x10\n" + "zip1 v28.4s, v5.4s, v20.4s\n" + "zip1 v14.4s, v30.4s, v16.4s\n" + "ldr q9, [x21], #0x10\n" + "ldr q23, [x20], #0x10\n" + "zip2 v13.4s, v5.4s, v20.4s\n" + "zip2 v30.4s, v30.4s, v16.4s\n" + "zip1 v16.4s, v29.4s, v22.4s\n" + "zip1 v5.4s, v6.4s, v0.4s\n" + "zip2 v22.4s, v29.4s, v22.4s\n" + "zip2 v0.4s, v6.4s, v0.4s\n" + "zip1 v7.4s, v2.4s, v9.4s\n" + "zip1 v19.4s, v10.4s, v23.4s\n" + "zip2 v21.4s, v2.4s, v9.4s\n" + "zip2 v25.4s, v10.4s, v23.4s\n" + "zip1 v11.4s, v15.4s, v8.4s\n" + "zip1 v9.4s, v17.4s, v27.4s\n" + "zip1 v6.4s, v18.4s, v4.4s\n" + "zip1 v2.4s, v3.4s, v1.4s\n" + "zip1 v29.4s, v16.4s, v5.4s\n" + "zip1 v20.4s, v22.4s, v0.4s\n" + "zip1 v10.4s, v26.4s, v31.4s\n" + "zip1 v23.4s, v12.4s, v24.4s\n" + ".inst 0x0ea1696b // bfcvtn v11.4h, v11.4s\n" + "zip2 v8.4s, v15.4s, v8.4s\n" + "zip1 v15.4s, v28.4s, v14.4s\n" + ".inst 0x0ea16929 // bfcvtn v9.4h, v9.4s\n" + "zip2 v27.4s, v17.4s, v27.4s\n" + "zip1 v17.4s, v13.4s, v30.4s\n" + ".inst 0x0ea168c6 // bfcvtn v6.4h, v6.4s\n" + "zip2 v4.4s, v18.4s, v4.4s\n" + "zip1 v18.4s, v7.4s, v19.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "zip2 v1.4s, v3.4s, v1.4s\n" + "zip1 v3.4s, v21.4s, v25.4s\n" + ".inst 0x0ea16bbd // bfcvtn v29.4h, v29.4s\n" + "zip2 v5.4s, v16.4s, v5.4s\n" + ".inst 0x0ea16a94 // bfcvtn v20.4h, v20.4s\n" + "zip2 v0.4s, v22.4s, v0.4s\n" + ".inst 0x0ea16956 // bfcvtn v22.4h, v10.4s\n" + "zip2 v31.4s, v26.4s, v31.4s\n" + ".inst 0x0ea16aea // bfcvtn v10.4h, v23.4s\n" + "zip2 v26.4s, v12.4s, v24.4s\n" + ".inst 0x0ea169ef // bfcvtn v15.4h, v15.4s\n" + "zip2 v12.4s, v28.4s, v14.4s\n" + ".inst 0x0ea16a2e // bfcvtn v14.4h, v17.4s\n" + "zip2 v24.4s, v13.4s, v30.4s\n" + ".inst 0x0ea16a57 // bfcvtn v23.4h, v18.4s\n" + "zip2 v18.4s, v7.4s, v19.4s\n" + ".inst 0x0ea16871 // bfcvtn v17.4h, v3.4s\n" + "zip2 v16.4s, v21.4s, v25.4s\n" + ".inst 0x4ea1690b // bfcvtn2 v11.8h, v8.4s\n" + ".inst 0x4ea16b69 // bfcvtn2 v9.8h, v27.4s\n" + ".inst 0x4ea16886 // bfcvtn2 v6.8h, v4.4s\n" + ".inst 0x4ea16822 // bfcvtn2 v2.8h, v1.4s\n" + ".inst 0x4ea168bd // bfcvtn2 v29.8h, v5.4s\n" + ".inst 0x4ea16814 // bfcvtn2 v20.8h, v0.4s\n" + ".inst 0x4ea16bf6 // bfcvtn2 v22.8h, v31.4s\n" + ".inst 0x4ea16b4a // bfcvtn2 v10.8h, v26.4s\n" + "str q11, [x27, #0x0]\n" + ".inst 0x4ea1698f // bfcvtn2 v15.8h, v12.4s\n" + ".inst 0x4ea16b0e // bfcvtn2 v14.8h, v24.4s\n" + "str q9, [x27, #0x10]\n" + ".inst 0x4ea16a57 // bfcvtn2 v23.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q6, [x27, #0x20]\n" + "str q2, [x27, #0x30]\n" + "str q29, [x27, #0x40]\n" + "str q20, [x27, #0x50]\n" + "str q22, [x27, #0x60]\n" + "str q10, [x27, #0x70]\n" + "str q15, [x27, #0x80]\n" + "str q14, [x27, #0x90]\n" + "str q23, [x27, #0xa0]\n" + "str q17, [x27, #0xb0]\n" + "add x27, x27, %x[out_stride]\n" + "bge 5b\n" + "6:" // Main row loop: Column loop skip + "cbz x28, 11f\n" + "cmp x28, #0x4\n" + "movi v16.16b, #0x0\n" + "str q16, [x27, #0x0]\n" + "str q16, [x27, #0x10]\n" 
+ "str q16, [x27, #0x20]\n" + "str q16, [x27, #0x30]\n" + "str q16, [x27, #0x40]\n" + "str q16, [x27, #0x50]\n" + "str q16, [x27, #0x60]\n" + "str q16, [x27, #0x70]\n" + "str q16, [x27, #0x80]\n" + "str q16, [x27, #0x90]\n" + "str q16, [x27, #0xa0]\n" + "str q16, [x27, #0xb0]\n" + "blt 8f\n" + "7:" // Main row loop: width 4 loop: loop + "ldr q25, [x9], #0x10\n" + "ldr q24, [x26], #0x10\n" + "sub x28, x28, #0x4\n" + "ldr q21, [x25], #0x10\n" + "ldr q20, [x24], #0x10\n" + "cmp x28, #0x4\n" + "ldr q23, [x23], #0x10\n" + "ldr q19, [x22], #0x10\n" + "ldr q18, [x21], #0x10\n" + "ldr q17, [x20], #0x10\n" + "zip1 v22.4s, v25.4s, v21.4s\n" + "zip1 v16.4s, v24.4s, v20.4s\n" + "zip2 v21.4s, v25.4s, v21.4s\n" + "zip2 v20.4s, v24.4s, v20.4s\n" + "zip1 v27.4s, v23.4s, v18.4s\n" + "zip1 v26.4s, v19.4s, v17.4s\n" + "zip2 v25.4s, v23.4s, v18.4s\n" + "zip2 v24.4s, v19.4s, v17.4s\n" + "zip1 v19.4s, v22.4s, v16.4s\n" + "zip1 v18.4s, v21.4s, v20.4s\n" + "zip1 v17.4s, v27.4s, v26.4s\n" + "zip2 v23.4s, v22.4s, v16.4s\n" + "zip1 v16.4s, v25.4s, v24.4s\n" + "zip2 v22.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a75 // bfcvtn v21.4h, v19.4s\n" + ".inst 0x0ea16a54 // bfcvtn v20.4h, v18.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v27.4s, v26.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v25.4s, v24.4s\n" + ".inst 0x4ea16af5 // bfcvtn2 v21.8h, v23.4s\n" + ".inst 0x4ea16ad4 // bfcvtn2 v20.8h, v22.4s\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q21, [x27, #0x0]\n" + "str q20, [x27, #0x10]\n" + "str q19, [x27, #0x60]\n" + "str q17, [x27, #0x70]\n" + "add x27, x27, #0x20\n" + "bge 7b\n" + "8:" // Main row loop: width 4 loop: skip + "cmp x28, #0x1\n" + "blt 10f\n" + "9:" // Main row loop: width 1 loop: loop + "ldr s23, [x9], #0x4\n" + "ldr s22, [x26], #0x4\n" + "sub x28, x28, #0x1\n" + "ldr s19, [x25], #0x4\n" + "ldr s17, [x24], #0x4\n" + "cmp x28, #0x1\n" + "ldr s21, [x23], #0x4\n" + "ldr s20, [x22], #0x4\n" + "ldr s18, [x21], #0x4\n" + "ldr s16, [x20], #0x4\n" + "zip1 v19.4s, v23.4s, v19.4s\n" + "zip1 v17.4s, v22.4s, v17.4s\n" + "zip1 v18.4s, v21.4s, v18.4s\n" + "zip1 v16.4s, v20.4s, v16.4s\n" + "zip1 v17.4s, v19.4s, v17.4s\n" + "zip1 v16.4s, v18.4s, v16.4s\n" + ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "str d17, [x27, #0x0]\n" + "str d16, [x27, #0x60]\n" + "add x27, x27, #0x8\n" + "bge 9b\n" + "10:" // Main row loop: width 1 loop: skip + "11:" // Main row loop: odd col skip + "cmp %x[height], #0x8\n" + "add %x[out], %x[out], #0xc0\n" + "bge 4b\n" + "cbz %x[height], 21f\n" + "12:" // Main loop skip + "13:" // Tail row loop: Head + "mov x9, %x[in]\n" + "mov x20, %x[width]\n" + "cmp %x[height], #0x3\n" + "mov x27, %x[out]\n" + "add x26, x9, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "csel x25, x25, %x[pad_row], GE\n" + "add %x[in], x24, %x[in_stride]\n" + "csel x24, x24, %x[pad_row], GT\n" + "cmp %x[height], #0x1\n" + "sub %x[height], %x[height], #0x4\n" + "csel x26, x26, %x[pad_row], GT\n" + "cmp x20, #0xc\n" + "blt 15f\n" + "14:" // Tail row loop: Column loop + "ldr q24, [x9], #0x10\n" + "ldr q23, [x26], #0x10\n" + "sub x20, x20, #0xc\n" + "ldr q22, [x25], #0x10\n" + "ldr q16, [x24], #0x10\n" + "cmp x20, #0xc\n" + "ldr q28, [x9], #0x10\n" + "ldr q27, [x26], #0x10\n" + "ldr q21, [x25], #0x10\n" + "ldr q20, [x24], #0x10\n" + "ldr q19, [x9], #0x10\n" + "zip1 v26.4s, v24.4s, v22.4s\n" + "zip1 v25.4s, v23.4s, v16.4s\n" + "ldr 
q18, [x26], #0x10\n" + "ldr q17, [x25], #0x10\n" + "zip2 v24.4s, v24.4s, v22.4s\n" + "zip2 v23.4s, v23.4s, v16.4s\n" + "ldr q16, [x24], #0x10\n" + "zip1 v2.4s, v28.4s, v21.4s\n" + "zip1 v22.4s, v27.4s, v20.4s\n" + "zip2 v1.4s, v28.4s, v21.4s\n" + "zip2 v0.4s, v27.4s, v20.4s\n" + "zip1 v31.4s, v19.4s, v17.4s\n" + "zip1 v30.4s, v18.4s, v16.4s\n" + "zip2 v29.4s, v19.4s, v17.4s\n" + "zip2 v28.4s, v18.4s, v16.4s\n" + "zip1 v21.4s, v26.4s, v25.4s\n" + "zip1 v20.4s, v24.4s, v23.4s\n" + "zip1 v19.4s, v2.4s, v22.4s\n" + "zip1 v18.4s, v1.4s, v0.4s\n" + "zip1 v17.4s, v31.4s, v30.4s\n" + "zip1 v16.4s, v29.4s, v28.4s\n" + ".inst 0x0ea16abb // bfcvtn v27.4h, v21.4s\n" + "zip2 v26.4s, v26.4s, v25.4s\n" + ".inst 0x0ea16a99 // bfcvtn v25.4h, v20.4s\n" + "zip2 v24.4s, v24.4s, v23.4s\n" + ".inst 0x0ea16a77 // bfcvtn v23.4h, v19.4s\n" + "zip2 v22.4s, v2.4s, v22.4s\n" + ".inst 0x0ea16a55 // bfcvtn v21.4h, v18.4s\n" + "zip2 v20.4s, v1.4s, v0.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v31.4s, v30.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v29.4s, v28.4s\n" + ".inst 0x4ea16b5b // bfcvtn2 v27.8h, v26.4s\n" + ".inst 0x4ea16b19 // bfcvtn2 v25.8h, v24.4s\n" + ".inst 0x4ea16ad7 // bfcvtn2 v23.8h, v22.4s\n" + ".inst 0x4ea16a95 // bfcvtn2 v21.8h, v20.4s\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q27, [x27, #0x0]\n" + "str q25, [x27, #0x10]\n" + "str q23, [x27, #0x20]\n" + "str q21, [x27, #0x30]\n" + "str q19, [x27, #0x40]\n" + "str q17, [x27, #0x50]\n" + "add x27, x27, %x[out_stride]\n" + "bge 14b\n" + "15:" // Tail row loop: Column loop skip + "cbz x20, 20f\n" + "cmp x20, #0x4\n" + "movi v16.16b, #0x0\n" + "str q16, [x27, #0x0]\n" + "str q16, [x27, #0x10]\n" + "str q16, [x27, #0x20]\n" + "str q16, [x27, #0x30]\n" + "str q16, [x27, #0x40]\n" + "str q16, [x27, #0x50]\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr q21, [x9], #0x10\n" + "ldr q20, [x26], #0x10\n" + "sub x20, x20, #0x4\n" + "ldr q19, [x25], #0x10\n" + "ldr q17, [x24], #0x10\n" + "cmp x20, #0x4\n" + "zip1 v18.4s, v21.4s, v19.4s\n" + "zip1 v16.4s, v20.4s, v17.4s\n" + "zip2 v21.4s, v21.4s, v19.4s\n" + "zip2 v20.4s, v20.4s, v17.4s\n" + "zip1 v17.4s, v18.4s, v16.4s\n" + "zip2 v19.4s, v18.4s, v16.4s\n" + "zip1 v16.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a32 // bfcvtn v18.4h, v17.4s\n" + "zip2 v17.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + ".inst 0x4ea16a72 // bfcvtn2 v18.8h, v19.4s\n" + ".inst 0x4ea16a30 // bfcvtn2 v16.8h, v17.4s\n" + "str q18, [x27, #0x0]\n" + "str q16, [x27, #0x10]\n" + "add x27, x27, #0x20\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr s19, [x9], #0x4\n" + "ldr s18, [x26], #0x4\n" + "sub x20, x20, #0x1\n" + "ldr s17, [x25], #0x4\n" + "ldr s16, [x24], #0x4\n" + "cmp x20, #0x1\n" + "zip1 v17.4s, v19.4s, v17.4s\n" + "zip1 v16.4s, v18.4s, v16.4s\n" + "zip1 v16.4s, v17.4s, v16.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "str d16, [x27, #0x0]\n" + "add x27, x27, #0x8\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "20:" // Tail row loop: odd col skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x60\n" + "bge 13b\n" + "21:" // Done + : [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out) + : [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), [width] "r"(width) + : "cc", "memory", "v0", "v1", "v2", 
"v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..7ef908465f7f60d943645fd5c1fd7d5d42896416 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h @@ -0,0 +1,80 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets n step value. +/// +/// The starting row index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(void); + +/// Gets the offset in bytes to the data element in the RHS matrix buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n_idx); + +/// Gets the offset in bytes to the data element in the bias buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_bias_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n_idx); + +/// Gets the offset in bytes to the data element in the packed RHS buffer. +/// +/// @param[in] n_idx Row index. +/// @param[in] k Number of columns. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n_idx, size_t k); + +/// Gets the size in bytes of the packed RHS buffer. +/// +/// @param[in] n Number of rows. +/// @param[in] k Number of columns. +/// +/// @return The size in bytes of the packed RHS buffer. +size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon(size_t n, size_t k); + +/// Runs the RHS packing function for matrix multiplication. +/// +/// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset +/// calculated using the following functions: +/// +/// * RHS: @ref kai_get_rhs_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon. +/// * Bias: @ref kai_get_bias_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon. +/// * Output: @ref kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon. +/// +/// @param[in] num_groups Number of groups. It must be 1. +/// @param[in] n Number of columns of the output matrix. +/// @param[in] k Common dimension between the LHS and RHS matrix. +/// @param[in] nr Block size in N dimension. It must be 12. +/// @param[in] kr Block size in K dimension. It must be 4. +/// @param[in] sr Number of kr splits. It must be 1. +/// @param[in] rhs_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs RHS matrix data buffer. +/// @param[in] bias Bias matrix data buffer. +/// @param[in] scale Scale data buffer. It must be NULL. +/// @param[out] rhs_packed Packed RHS matrix. +/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. +/// @param[in] params Extra packing parameters. It must be NULL. 
+void kai_run_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon.c new file mode 100644 index 0000000000000000000000000000000000000000..993fb6cb0b5975b536ac07af6d9d2c052922e3f1 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon.c @@ -0,0 +1,878 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) +#error This file must be compiled for AArch64. +#else // Architectural features check. +#include +#include +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_nr = 24; +static const size_t kai_kr = 4; + +size_t kai_get_n_step_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon(void) { + return kai_nr; +} + +size_t kai_get_rhs_offset_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon(size_t n_idx) { + KAI_ASSUME(n_idx % kai_nr == 0); + + return n_idx * sizeof(float); +} + +size_t kai_get_bias_offset_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon(size_t n_idx) { + return n_idx * sizeof(uint32_t); +} + +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_nr == 0); + + return n_idx * (sizeof(uint32_t) + kai_roundup(k, kai_kr) * sizeof(bfloat16_t)); +} + +size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon(size_t n, size_t k) { + return kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon(kai_roundup(n, kai_nr), k); +} + +void kai_run_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { + KAI_ASSUME(num_groups == 1); + KAI_ASSUME(nr == kai_nr); + KAI_ASSUME(kr == kai_kr); + KAI_ASSUME(sr == 1); + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(bias != NULL); + KAI_ASSUME(scale == NULL); + KAI_ASSUME(rhs_packed != NULL); + KAI_ASSUME(extra_bytes == 0); + KAI_ASSUME(params == NULL); + + size_t height = k; + const size_t width = n; + const void* in = rhs; + void* out = rhs_packed; + const size_t in_stride = rhs_stride; + float* pad_row = (float*)alloca(width * sizeof(float)); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(float)); + } + + size_t out_stride = kai_nr * kai_roundup(height, 4) * sizeof(bfloat16_t) + kai_nr * sizeof(uint32_t); + + __asm__ __volatile__( + "mov x22, %x[width]\n" + "mov x21, %x[out]\n" + "cmp x22, #0x18\n" + "blt 2f\n" + "1:" // Bias: Full loop + "ldr q3, [%x[bias], #0x0]\n" + "ldr q28, [%x[bias], #0x10]\n" + "sub x22, x22, #0x18\n" + "ldr q14, [%x[bias], #0x20]\n" + "ldr q0, [%x[bias], #0x30]\n" + "cmp x22, #0x18\n" + "ldr q1, [%x[bias], #0x40]\n" + "ldr q7, [%x[bias], #0x50]\n" + "add %x[bias], %x[bias], #0x60\n" + "str q3, [x21, #0x0]\n" + "str q28, [x21, #0x10]\n" + "str q14, [x21, #0x20]\n" + "str q0, [x21, #0x30]\n" + "str q1, [x21, #0x40]\n" + "str q7, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "bge 1b\n" + "cbz x22, 3f\n" + "2:" // Bias: Tail loop + "ldr w20, [%x[bias], 
#0x0]\n" + "sub x22, x22, #0x1\n" + "add %x[bias], %x[bias], #0x4\n" + "cmp x22, #0x0\n" + "str w20, [x21]\n" + "add x21, x21, #0x4\n" + "bgt 2b\n" + "3:" // Bias: Done + "cmp %x[height], #0x8\n" + "add %x[out], %x[out], #0x60\n" + "blt 14f\n" + "4:" // Main row loop: Head + "mov x9, %x[in]\n" + "mov x28, %x[width]\n" + "mov x27, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "add x26, x9, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "cmp x28, #0x18\n" + "add x23, x24, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x21, x22, %x[in_stride]\n" + "add x20, x21, %x[in_stride]\n" + "add %x[in], x20, %x[in_stride]\n" + "blt 6f\n" + "5:" // Main row loop: Column loop + "ldr q14, [x9], #0x10\n" + "ldr q13, [x26], #0x10\n" + "sub x28, x28, #0x18\n" + "ldr q16, [x25], #0x10\n" + "ldr q15, [x24], #0x10\n" + "cmp x28, #0x18\n" + "ldr q12, [x23], #0x10\n" + "ldr q20, [x22], #0x10\n" + "ldr q26, [x21], #0x10\n" + "ldr q4, [x20], #0x10\n" + "ldr q10, [x9], #0x10\n" + "ldr q25, [x26], #0x10\n" + "zip1 v30.4s, v14.4s, v16.4s\n" + "zip1 v17.4s, v13.4s, v15.4s\n" + "ldr q22, [x25], #0x10\n" + "ldr q21, [x24], #0x10\n" + "zip2 v14.4s, v14.4s, v16.4s\n" + "zip2 v24.4s, v13.4s, v15.4s\n" + "ldr q5, [x23], #0x10\n" + "ldr q0, [x22], #0x10\n" + "zip1 v9.4s, v12.4s, v26.4s\n" + "zip1 v13.4s, v20.4s, v4.4s\n" + "ldr q1, [x21], #0x10\n" + "ldr q15, [x20], #0x10\n" + "zip2 v19.4s, v12.4s, v26.4s\n" + "zip2 v16.4s, v20.4s, v4.4s\n" + "ldr q20, [x9], #0x10\n" + "ldr q2, [x26], #0x10\n" + "zip1 v29.4s, v10.4s, v22.4s\n" + "zip1 v28.4s, v25.4s, v21.4s\n" + "ldr q18, [x25], #0x10\n" + "ldr q6, [x24], #0x10\n" + "zip2 v22.4s, v10.4s, v22.4s\n" + "zip2 v25.4s, v25.4s, v21.4s\n" + "ldr q7, [x23], #0x10\n" + "ldr q23, [x22], #0x10\n" + "zip1 v3.4s, v5.4s, v1.4s\n" + "zip1 v11.4s, v0.4s, v15.4s\n" + "ldr q26, [x21], #0x10\n" + "ldr q4, [x20], #0x10\n" + "zip2 v10.4s, v5.4s, v1.4s\n" + "zip2 v8.4s, v0.4s, v15.4s\n" + "ldr q1, [x9], #0x10\n" + "ldr q15, [x26], #0x10\n" + "zip1 v31.4s, v20.4s, v18.4s\n" + "zip1 v12.4s, v2.4s, v6.4s\n" + "ldr q5, [x25], #0x10\n" + "ldr q0, [x24], #0x10\n" + "zip2 v27.4s, v20.4s, v18.4s\n" + "zip2 v6.4s, v2.4s, v6.4s\n" + "ldr q18, [x23], #0x10\n" + "ldr q20, [x22], #0x10\n" + "zip1 v21.4s, v7.4s, v26.4s\n" + "zip1 v2.4s, v23.4s, v4.4s\n" + "zip2 v7.4s, v7.4s, v26.4s\n" + "ldr q26, [x21], #0x10\n" + "zip2 v23.4s, v23.4s, v4.4s\n" + "zip1 v4.4s, v1.4s, v5.4s\n" + "zip2 v5.4s, v1.4s, v5.4s\n" + "zip1 v1.4s, v15.4s, v0.4s\n" + "zip2 v15.4s, v15.4s, v0.4s\n" + "zip1 v0.4s, v18.4s, v26.4s\n" + "zip2 v18.4s, v18.4s, v26.4s\n" + "zip1 v26.4s, v30.4s, v17.4s\n" + "zip2 v17.4s, v30.4s, v17.4s\n" + "ldr q30, [x20], #0x10\n" + ".inst 0x0ea16b5a // bfcvtn v26.4h, v26.4s\n" + ".inst 0x4ea16a3a // bfcvtn2 v26.8h, v17.4s\n" + "zip1 v17.4s, v14.4s, v24.4s\n" + "zip2 v14.4s, v14.4s, v24.4s\n" + "ldr q24, [x9], #0x10\n" + ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" + ".inst 0x4ea169d1 // bfcvtn2 v17.8h, v14.4s\n" + "zip1 v14.4s, v20.4s, v30.4s\n" + "zip2 v30.4s, v20.4s, v30.4s\n" + "zip1 v20.4s, v29.4s, v28.4s\n" + "zip2 v28.4s, v29.4s, v28.4s\n" + "ldr q29, [x26], #0x10\n" + ".inst 0x0ea16a94 // bfcvtn v20.4h, v20.4s\n" + ".inst 0x4ea16b94 // bfcvtn2 v20.8h, v28.4s\n" + "zip1 v28.4s, v22.4s, v25.4s\n" + "zip2 v22.4s, v22.4s, v25.4s\n" + "ldr q25, [x25], #0x10\n" + ".inst 0x0ea16b9c // bfcvtn v28.4h, v28.4s\n" + ".inst 0x4ea16adc // bfcvtn2 v28.8h, v22.4s\n" + "zip1 v22.4s, v31.4s, v12.4s\n" + "zip2 v31.4s, v31.4s, v12.4s\n" + "ldr q12, [x24], #0x10\n" + 
".inst 0x0ea16ad6 // bfcvtn v22.4h, v22.4s\n" + ".inst 0x4ea16bf6 // bfcvtn2 v22.8h, v31.4s\n" + "zip1 v31.4s, v24.4s, v25.4s\n" + "zip2 v24.4s, v24.4s, v25.4s\n" + "zip1 v25.4s, v29.4s, v12.4s\n" + "zip2 v12.4s, v29.4s, v12.4s\n" + "zip1 v29.4s, v27.4s, v6.4s\n" + "zip2 v27.4s, v27.4s, v6.4s\n" + "ldr q6, [x23], #0x10\n" + ".inst 0x0ea16bbd // bfcvtn v29.4h, v29.4s\n" + ".inst 0x4ea16b7d // bfcvtn2 v29.8h, v27.4s\n" + "zip1 v27.4s, v4.4s, v1.4s\n" + "zip2 v4.4s, v4.4s, v1.4s\n" + "ldr q1, [x22], #0x10\n" + ".inst 0x0ea16b7b // bfcvtn v27.4h, v27.4s\n" + ".inst 0x4ea1689b // bfcvtn2 v27.8h, v4.4s\n" + "zip1 v4.4s, v5.4s, v15.4s\n" + "zip2 v5.4s, v5.4s, v15.4s\n" + "ldr q15, [x21], #0x10\n" + ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n" + ".inst 0x4ea168a4 // bfcvtn2 v4.8h, v5.4s\n" + "zip1 v5.4s, v31.4s, v25.4s\n" + "zip2 v25.4s, v31.4s, v25.4s\n" + "ldr q31, [x20], #0x10\n" + ".inst 0x0ea168a5 // bfcvtn v5.4h, v5.4s\n" + ".inst 0x4ea16b25 // bfcvtn2 v5.8h, v25.4s\n" + "zip1 v25.4s, v6.4s, v15.4s\n" + "zip2 v15.4s, v6.4s, v15.4s\n" + "zip1 v6.4s, v1.4s, v31.4s\n" + "zip2 v31.4s, v1.4s, v31.4s\n" + "zip1 v1.4s, v24.4s, v12.4s\n" + "zip2 v12.4s, v24.4s, v12.4s\n" + "ldr q24, [x9], #0x10\n" + ".inst 0x0ea16821 // bfcvtn v1.4h, v1.4s\n" + ".inst 0x4ea16981 // bfcvtn2 v1.8h, v12.4s\n" + "zip1 v12.4s, v9.4s, v13.4s\n" + "zip2 v13.4s, v9.4s, v13.4s\n" + "ldr q9, [x26], #0x10\n" + ".inst 0x0ea1698c // bfcvtn v12.4h, v12.4s\n" + ".inst 0x4ea169ac // bfcvtn2 v12.8h, v13.4s\n" + "zip1 v13.4s, v19.4s, v16.4s\n" + "zip2 v19.4s, v19.4s, v16.4s\n" + "ldr q16, [x25], #0x10\n" + ".inst 0x0ea169ad // bfcvtn v13.4h, v13.4s\n" + ".inst 0x4ea16a6d // bfcvtn2 v13.8h, v19.4s\n" + "zip1 v19.4s, v3.4s, v11.4s\n" + "zip2 v11.4s, v3.4s, v11.4s\n" + "ldr q3, [x24], #0x10\n" + ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n" + ".inst 0x4ea16973 // bfcvtn2 v19.8h, v11.4s\n" + "zip1 v11.4s, v24.4s, v16.4s\n" + "zip2 v16.4s, v24.4s, v16.4s\n" + "zip1 v24.4s, v9.4s, v3.4s\n" + "zip2 v3.4s, v9.4s, v3.4s\n" + "zip1 v9.4s, v10.4s, v8.4s\n" + "zip2 v8.4s, v10.4s, v8.4s\n" + "ldr q10, [x23], #0x10\n" + ".inst 0x0ea16929 // bfcvtn v9.4h, v9.4s\n" + ".inst 0x4ea16909 // bfcvtn2 v9.8h, v8.4s\n" + "zip1 v8.4s, v11.4s, v24.4s\n" + "zip2 v24.4s, v11.4s, v24.4s\n" + "ldr q11, [x22], #0x10\n" + ".inst 0x0ea16908 // bfcvtn v8.4h, v8.4s\n" + ".inst 0x4ea16b08 // bfcvtn2 v8.8h, v24.4s\n" + "zip1 v24.4s, v16.4s, v3.4s\n" + "zip2 v3.4s, v16.4s, v3.4s\n" + "ldr q16, [x21], #0x10\n" + ".inst 0x0ea16b18 // bfcvtn v24.4h, v24.4s\n" + ".inst 0x4ea16878 // bfcvtn2 v24.8h, v3.4s\n" + "ldr q3, [x20], #0x10\n" + "str q26, [x27, #0x0]\n" + "zip1 v26.4s, v21.4s, v2.4s\n" + "zip2 v21.4s, v21.4s, v2.4s\n" + "str q17, [x27, #0x10]\n" + "zip1 v17.4s, v10.4s, v16.4s\n" + "zip2 v16.4s, v10.4s, v16.4s\n" + "str q20, [x27, #0x20]\n" + "zip1 v20.4s, v11.4s, v3.4s\n" + "zip2 v10.4s, v11.4s, v3.4s\n" + "str q28, [x27, #0x30]\n" + "zip1 v28.4s, v7.4s, v23.4s\n" + "zip1 v3.4s, v0.4s, v14.4s\n" + "str q22, [x27, #0x40]\n" + "zip1 v11.4s, v18.4s, v30.4s\n" + "zip1 v2.4s, v25.4s, v6.4s\n" + "str q29, [x27, #0x50]\n" + "zip1 v22.4s, v15.4s, v31.4s\n" + "zip1 v29.4s, v17.4s, v20.4s\n" + "str q27, [x27, #0x60]\n" + "zip1 v27.4s, v16.4s, v10.4s\n" + ".inst 0x0ea16b5a // bfcvtn v26.4h, v26.4s\n" + "str q4, [x27, #0x70]\n" + ".inst 0x0ea16b84 // bfcvtn v4.4h, v28.4s\n" + "zip2 v7.4s, v7.4s, v23.4s\n" + "str q5, [x27, #0x80]\n" + ".inst 0x0ea16877 // bfcvtn v23.4h, v3.4s\n" + "zip2 v28.4s, v0.4s, v14.4s\n" + "str q1, [x27, #0x90]\n" + ".inst 0x0ea1696e // bfcvtn v14.4h, 
v11.4s\n" + "zip2 v1.4s, v18.4s, v30.4s\n" + "str q8, [x27, #0xa0]\n" + ".inst 0x0ea16848 // bfcvtn v8.4h, v2.4s\n" + "zip2 v0.4s, v25.4s, v6.4s\n" + "str q24, [x27, #0xb0]\n" + ".inst 0x0ea16ac5 // bfcvtn v5.4h, v22.4s\n" + "zip2 v6.4s, v15.4s, v31.4s\n" + "str q12, [x27, #0xc0]\n" + ".inst 0x0ea16ba2 // bfcvtn v2.4h, v29.4s\n" + "zip2 v12.4s, v17.4s, v20.4s\n" + "str q13, [x27, #0xd0]\n" + ".inst 0x0ea16b7e // bfcvtn v30.4h, v27.4s\n" + "zip2 v17.4s, v16.4s, v10.4s\n" + "str q19, [x27, #0xe0]\n" + ".inst 0x4ea16aba // bfcvtn2 v26.8h, v21.4s\n" + ".inst 0x4ea168e4 // bfcvtn2 v4.8h, v7.4s\n" + "str q9, [x27, #0xf0]\n" + ".inst 0x4ea16b97 // bfcvtn2 v23.8h, v28.4s\n" + ".inst 0x4ea1682e // bfcvtn2 v14.8h, v1.4s\n" + ".inst 0x4ea16808 // bfcvtn2 v8.8h, v0.4s\n" + ".inst 0x4ea168c5 // bfcvtn2 v5.8h, v6.4s\n" + ".inst 0x4ea16982 // bfcvtn2 v2.8h, v12.4s\n" + ".inst 0x4ea16a3e // bfcvtn2 v30.8h, v17.4s\n" + "str q26, [x27, #0x100]\n" + "str q4, [x27, #0x110]\n" + "str q23, [x27, #0x120]\n" + "str q14, [x27, #0x130]\n" + "str q8, [x27, #0x140]\n" + "str q5, [x27, #0x150]\n" + "str q2, [x27, #0x160]\n" + "str q30, [x27, #0x170]\n" + "add x27, x27, %x[out_stride]\n" + "bge 5b\n" + "6:" // Main row loop: Column loop skip + "cbz x28, 13f\n" + "cmp x28, #0x10\n" + "movi v15.16b, #0x0\n" + "str q15, [x27, #0x0]\n" + "str q15, [x27, #0x10]\n" + "str q15, [x27, #0x20]\n" + "str q15, [x27, #0x30]\n" + "str q15, [x27, #0x40]\n" + "str q15, [x27, #0x50]\n" + "str q15, [x27, #0x60]\n" + "str q15, [x27, #0x70]\n" + "str q15, [x27, #0x80]\n" + "str q15, [x27, #0x90]\n" + "str q15, [x27, #0xa0]\n" + "str q15, [x27, #0xb0]\n" + "str q15, [x27, #0xc0]\n" + "str q15, [x27, #0xd0]\n" + "str q15, [x27, #0xe0]\n" + "str q15, [x27, #0xf0]\n" + "str q15, [x27, #0x100]\n" + "str q15, [x27, #0x110]\n" + "str q15, [x27, #0x120]\n" + "str q15, [x27, #0x130]\n" + "str q15, [x27, #0x140]\n" + "str q15, [x27, #0x150]\n" + "str q15, [x27, #0x160]\n" + "str q15, [x27, #0x170]\n" + "blt 8f\n" + "7:" // Main row loop: width 16 loop: loop + "ldr q18, [x9], #0x10\n" + "ldr q15, [x26], #0x10\n" + "sub x28, x28, #0x10\n" + "ldr q17, [x25], #0x10\n" + "ldr q16, [x24], #0x10\n" + "cmp x28, #0x10\n" + "ldr q20, [x23], #0x10\n" + "ldr q6, [x22], #0x10\n" + "ldr q31, [x21], #0x10\n" + "ldr q14, [x20], #0x10\n" + "ldr q26, [x9], #0x10\n" + "ldr q30, [x26], #0x10\n" + "zip1 v1.4s, v18.4s, v17.4s\n" + "zip1 v24.4s, v15.4s, v16.4s\n" + "ldr q3, [x25], #0x10\n" + "ldr q12, [x24], #0x10\n" + "zip2 v11.4s, v18.4s, v17.4s\n" + "zip2 v17.4s, v15.4s, v16.4s\n" + "ldr q15, [x23], #0x10\n" + "ldr q7, [x22], #0x10\n" + "zip1 v21.4s, v20.4s, v31.4s\n" + "zip1 v9.4s, v6.4s, v14.4s\n" + "ldr q18, [x21], #0x10\n" + "ldr q5, [x20], #0x10\n" + "zip2 v2.4s, v20.4s, v31.4s\n" + "zip2 v22.4s, v6.4s, v14.4s\n" + "ldr q0, [x9], #0x10\n" + "ldr q4, [x26], #0x10\n" + "zip1 v28.4s, v26.4s, v3.4s\n" + "zip1 v29.4s, v30.4s, v12.4s\n" + "ldr q6, [x25], #0x10\n" + "ldr q8, [x24], #0x10\n" + "zip2 v27.4s, v26.4s, v3.4s\n" + "zip2 v3.4s, v30.4s, v12.4s\n" + "ldr q10, [x23], #0x10\n" + "ldr q23, [x22], #0x10\n" + "zip1 v20.4s, v15.4s, v18.4s\n" + "zip1 v19.4s, v7.4s, v5.4s\n" + "ldr q13, [x21], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip2 v31.4s, v15.4s, v18.4s\n" + "zip2 v30.4s, v7.4s, v5.4s\n" + "ldr q5, [x9], #0x10\n" + "ldr q25, [x26], #0x10\n" + "zip1 v7.4s, v0.4s, v6.4s\n" + "zip1 v15.4s, v4.4s, v8.4s\n" + "ldr q18, [x25], #0x10\n" + "ldr q14, [x24], #0x10\n" + "zip2 v6.4s, v0.4s, v6.4s\n" + "zip2 v4.4s, v4.4s, v8.4s\n" + "ldr q8, [x23], #0x10\n" + "ldr q12, [x22], 
#0x10\n" + "zip1 v26.4s, v10.4s, v13.4s\n" + "zip1 v0.4s, v23.4s, v16.4s\n" + "zip2 v10.4s, v10.4s, v13.4s\n" + "ldr q13, [x21], #0x10\n" + "zip2 v16.4s, v23.4s, v16.4s\n" + "zip1 v23.4s, v5.4s, v18.4s\n" + "zip2 v5.4s, v5.4s, v18.4s\n" + "zip1 v18.4s, v25.4s, v14.4s\n" + "zip2 v25.4s, v25.4s, v14.4s\n" + "zip1 v14.4s, v8.4s, v13.4s\n" + "zip2 v13.4s, v8.4s, v13.4s\n" + "zip1 v8.4s, v1.4s, v24.4s\n" + "zip2 v24.4s, v1.4s, v24.4s\n" + "ldr q1, [x20], #0x10\n" + ".inst 0x0ea16908 // bfcvtn v8.4h, v8.4s\n" + ".inst 0x4ea16b08 // bfcvtn2 v8.8h, v24.4s\n" + "zip1 v24.4s, v11.4s, v17.4s\n" + "zip2 v11.4s, v11.4s, v17.4s\n" + "zip1 v17.4s, v12.4s, v1.4s\n" + "zip2 v1.4s, v12.4s, v1.4s\n" + "zip1 v12.4s, v28.4s, v29.4s\n" + "zip2 v28.4s, v28.4s, v29.4s\n" + "str q8, [x27, #0x0]\n" + "zip1 v29.4s, v27.4s, v3.4s\n" + "zip1 v8.4s, v7.4s, v15.4s\n" + ".inst 0x0ea16b18 // bfcvtn v24.4h, v24.4s\n" + "zip2 v27.4s, v27.4s, v3.4s\n" + "zip1 v3.4s, v6.4s, v4.4s\n" + ".inst 0x0ea1698c // bfcvtn v12.4h, v12.4s\n" + ".inst 0x0ea16bbd // bfcvtn v29.4h, v29.4s\n" + ".inst 0x0ea16908 // bfcvtn v8.4h, v8.4s\n" + "zip2 v15.4s, v7.4s, v15.4s\n" + "zip1 v7.4s, v23.4s, v18.4s\n" + ".inst 0x0ea16863 // bfcvtn v3.4h, v3.4s\n" + "zip2 v6.4s, v6.4s, v4.4s\n" + "zip1 v4.4s, v5.4s, v25.4s\n" + "zip2 v18.4s, v23.4s, v18.4s\n" + "zip1 v23.4s, v21.4s, v9.4s\n" + ".inst 0x0ea168e7 // bfcvtn v7.4h, v7.4s\n" + "zip2 v5.4s, v5.4s, v25.4s\n" + "zip1 v25.4s, v2.4s, v22.4s\n" + ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n" + "zip2 v9.4s, v21.4s, v9.4s\n" + "zip1 v21.4s, v20.4s, v19.4s\n" + ".inst 0x0ea16af7 // bfcvtn v23.4h, v23.4s\n" + ".inst 0x0ea16b39 // bfcvtn v25.4h, v25.4s\n" + "zip2 v22.4s, v2.4s, v22.4s\n" + "zip1 v2.4s, v31.4s, v30.4s\n" + "zip2 v20.4s, v20.4s, v19.4s\n" + "zip1 v19.4s, v26.4s, v0.4s\n" + ".inst 0x0ea16ab5 // bfcvtn v21.4h, v21.4s\n" + "zip2 v31.4s, v31.4s, v30.4s\n" + "zip1 v30.4s, v10.4s, v16.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "zip2 v26.4s, v26.4s, v0.4s\n" + "zip1 v0.4s, v14.4s, v17.4s\n" + ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n" + ".inst 0x0ea16bde // bfcvtn v30.4h, v30.4s\n" + "zip2 v16.4s, v10.4s, v16.4s\n" + "zip1 v10.4s, v13.4s, v1.4s\n" + "zip2 v14.4s, v14.4s, v17.4s\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "zip2 v17.4s, v13.4s, v1.4s\n" + ".inst 0x4ea16978 // bfcvtn2 v24.8h, v11.4s\n" + ".inst 0x4ea16b8c // bfcvtn2 v12.8h, v28.4s\n" + ".inst 0x0ea1695c // bfcvtn v28.4h, v10.4s\n" + ".inst 0x4ea16b7d // bfcvtn2 v29.8h, v27.4s\n" + ".inst 0x4ea169e8 // bfcvtn2 v8.8h, v15.4s\n" + ".inst 0x4ea168c3 // bfcvtn2 v3.8h, v6.4s\n" + ".inst 0x4ea16a47 // bfcvtn2 v7.8h, v18.4s\n" + ".inst 0x4ea168a4 // bfcvtn2 v4.8h, v5.4s\n" + "str q24, [x27, #0x10]\n" + ".inst 0x4ea16937 // bfcvtn2 v23.8h, v9.4s\n" + ".inst 0x4ea16ad9 // bfcvtn2 v25.8h, v22.4s\n" + "str q12, [x27, #0x20]\n" + ".inst 0x4ea16a95 // bfcvtn2 v21.8h, v20.4s\n" + ".inst 0x4ea16be2 // bfcvtn2 v2.8h, v31.4s\n" + "str q29, [x27, #0x30]\n" + ".inst 0x4ea16b53 // bfcvtn2 v19.8h, v26.4s\n" + ".inst 0x4ea16a1e // bfcvtn2 v30.8h, v16.4s\n" + "str q8, [x27, #0x40]\n" + ".inst 0x4ea169c0 // bfcvtn2 v0.8h, v14.4s\n" + ".inst 0x4ea16a3c // bfcvtn2 v28.8h, v17.4s\n" + "str q3, [x27, #0x50]\n" + "str q7, [x27, #0x60]\n" + "str q4, [x27, #0x70]\n" + "str q23, [x27, #0xc0]\n" + "str q25, [x27, #0xd0]\n" + "str q21, [x27, #0xe0]\n" + "str q2, [x27, #0xf0]\n" + "str q19, [x27, #0x100]\n" + "str q30, [x27, #0x110]\n" + "str q0, [x27, #0x120]\n" + "str q28, [x27, #0x130]\n" + "add x27, x27, #0x80\n" + "bge 7b\n" + "8:" // Main row 
loop: width 16 loop: skip + "cmp x28, #0x4\n" + "blt 10f\n" + "9:" // Main row loop: width 4 loop: loop + "ldr q25, [x9], #0x10\n" + "ldr q24, [x26], #0x10\n" + "sub x28, x28, #0x4\n" + "ldr q21, [x25], #0x10\n" + "ldr q20, [x24], #0x10\n" + "cmp x28, #0x4\n" + "ldr q23, [x23], #0x10\n" + "ldr q19, [x22], #0x10\n" + "ldr q18, [x21], #0x10\n" + "ldr q17, [x20], #0x10\n" + "zip1 v22.4s, v25.4s, v21.4s\n" + "zip1 v16.4s, v24.4s, v20.4s\n" + "zip2 v21.4s, v25.4s, v21.4s\n" + "zip2 v20.4s, v24.4s, v20.4s\n" + "zip1 v27.4s, v23.4s, v18.4s\n" + "zip1 v26.4s, v19.4s, v17.4s\n" + "zip2 v25.4s, v23.4s, v18.4s\n" + "zip2 v24.4s, v19.4s, v17.4s\n" + "zip1 v19.4s, v22.4s, v16.4s\n" + "zip1 v18.4s, v21.4s, v20.4s\n" + "zip1 v17.4s, v27.4s, v26.4s\n" + "zip2 v23.4s, v22.4s, v16.4s\n" + "zip1 v16.4s, v25.4s, v24.4s\n" + "zip2 v22.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a75 // bfcvtn v21.4h, v19.4s\n" + ".inst 0x0ea16a54 // bfcvtn v20.4h, v18.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v27.4s, v26.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v25.4s, v24.4s\n" + ".inst 0x4ea16af5 // bfcvtn2 v21.8h, v23.4s\n" + ".inst 0x4ea16ad4 // bfcvtn2 v20.8h, v22.4s\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q21, [x27, #0x0]\n" + "str q20, [x27, #0x10]\n" + "str q19, [x27, #0xc0]\n" + "str q17, [x27, #0xd0]\n" + "add x27, x27, #0x20\n" + "bge 9b\n" + "10:" // Main row loop: width 4 loop: skip + "cmp x28, #0x1\n" + "blt 12f\n" + "11:" // Main row loop: width 1 loop: loop + "ldr s23, [x9], #0x4\n" + "ldr s22, [x26], #0x4\n" + "sub x28, x28, #0x1\n" + "ldr s19, [x25], #0x4\n" + "ldr s17, [x24], #0x4\n" + "cmp x28, #0x1\n" + "ldr s21, [x23], #0x4\n" + "ldr s20, [x22], #0x4\n" + "ldr s18, [x21], #0x4\n" + "ldr s16, [x20], #0x4\n" + "zip1 v19.4s, v23.4s, v19.4s\n" + "zip1 v17.4s, v22.4s, v17.4s\n" + "zip1 v18.4s, v21.4s, v18.4s\n" + "zip1 v16.4s, v20.4s, v16.4s\n" + "zip1 v17.4s, v19.4s, v17.4s\n" + "zip1 v16.4s, v18.4s, v16.4s\n" + ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "str d17, [x27, #0x0]\n" + "str d16, [x27, #0xc0]\n" + "add x27, x27, #0x8\n" + "bge 11b\n" + "12:" // Main row loop: width 1 loop: skip + "13:" // Main row loop: odd col skip + "cmp %x[height], #0x8\n" + "add %x[out], %x[out], #0x180\n" + "bge 4b\n" + "cbz %x[height], 25f\n" + "14:" // Main loop skip + "15:" // Tail row loop: Head + "mov x9, %x[in]\n" + "mov x20, %x[width]\n" + "cmp %x[height], #0x3\n" + "mov x27, %x[out]\n" + "add x26, x9, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "csel x25, x25, %x[pad_row], GE\n" + "add %x[in], x24, %x[in_stride]\n" + "csel x24, x24, %x[pad_row], GT\n" + "cmp %x[height], #0x1\n" + "sub %x[height], %x[height], #0x4\n" + "csel x26, x26, %x[pad_row], GT\n" + "cmp x20, #0x18\n" + "blt 17f\n" + "16:" // Tail row loop: Column loop + "ldr q24, [x9], #0x10\n" + "ldr q21, [x26], #0x10\n" + "sub x20, x20, #0x18\n" + "ldr q19, [x25], #0x10\n" + "ldr q17, [x24], #0x10\n" + "cmp x20, #0x18\n" + "ldr q23, [x9], #0x10\n" + "ldr q20, [x26], #0x10\n" + "ldr q18, [x25], #0x10\n" + "ldr q16, [x24], #0x10\n" + "ldr q29, [x9], #0x10\n" + "zip1 v22.4s, v24.4s, v19.4s\n" + "zip1 v6.4s, v21.4s, v17.4s\n" + "ldr q28, [x26], #0x10\n" + "ldr q27, [x25], #0x10\n" + "zip2 v19.4s, v24.4s, v19.4s\n" + "zip2 v0.4s, v21.4s, v17.4s\n" + "ldr q21, [x24], #0x10\n" + "ldr q1, [x9], #0x10\n" + "zip1 v17.4s, v23.4s, v18.4s\n" + "zip1 v4.4s, v20.4s, 
v16.4s\n" + "ldr q31, [x26], #0x10\n" + "ldr q24, [x25], #0x10\n" + "zip2 v5.4s, v23.4s, v18.4s\n" + "zip2 v16.4s, v20.4s, v16.4s\n" + "ldr q23, [x24], #0x10\n" + "ldr q25, [x9], #0x10\n" + "zip1 v30.4s, v29.4s, v27.4s\n" + "zip1 v26.4s, v28.4s, v21.4s\n" + "ldr q20, [x26], #0x10\n" + "ldr q18, [x25], #0x10\n" + "zip2 v29.4s, v29.4s, v27.4s\n" + "zip2 v28.4s, v28.4s, v21.4s\n" + "ldr q13, [x24], #0x10\n" + "ldr q9, [x9], #0x10\n" + "zip1 v27.4s, v1.4s, v24.4s\n" + "zip1 v21.4s, v31.4s, v23.4s\n" + "ldr q8, [x26], #0x10\n" + "ldr q7, [x25], #0x10\n" + "zip2 v24.4s, v1.4s, v24.4s\n" + "zip2 v14.4s, v31.4s, v23.4s\n" + "ldr q1, [x24], #0x10\n" + "zip1 v23.4s, v25.4s, v18.4s\n" + "zip1 v3.4s, v20.4s, v13.4s\n" + "zip2 v12.4s, v25.4s, v18.4s\n" + "zip2 v13.4s, v20.4s, v13.4s\n" + "zip1 v15.4s, v9.4s, v7.4s\n" + "zip1 v18.4s, v8.4s, v1.4s\n" + "zip2 v9.4s, v9.4s, v7.4s\n" + "zip2 v10.4s, v8.4s, v1.4s\n" + "zip1 v11.4s, v22.4s, v6.4s\n" + "zip1 v7.4s, v19.4s, v0.4s\n" + "zip1 v1.4s, v17.4s, v4.4s\n" + "zip1 v2.4s, v5.4s, v16.4s\n" + "zip1 v31.4s, v30.4s, v26.4s\n" + "zip1 v25.4s, v29.4s, v28.4s\n" + "zip1 v8.4s, v27.4s, v21.4s\n" + "zip1 v20.4s, v24.4s, v14.4s\n" + ".inst 0x0ea1696b // bfcvtn v11.4h, v11.4s\n" + "zip2 v6.4s, v22.4s, v6.4s\n" + "zip1 v22.4s, v23.4s, v3.4s\n" + ".inst 0x0ea168e7 // bfcvtn v7.4h, v7.4s\n" + "zip2 v0.4s, v19.4s, v0.4s\n" + "zip1 v19.4s, v12.4s, v13.4s\n" + ".inst 0x0ea16821 // bfcvtn v1.4h, v1.4s\n" + "zip2 v4.4s, v17.4s, v4.4s\n" + "zip1 v17.4s, v15.4s, v18.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "zip2 v5.4s, v5.4s, v16.4s\n" + "zip1 v16.4s, v9.4s, v10.4s\n" + ".inst 0x0ea16bff // bfcvtn v31.4h, v31.4s\n" + "zip2 v26.4s, v30.4s, v26.4s\n" + ".inst 0x0ea16b3e // bfcvtn v30.4h, v25.4s\n" + "zip2 v29.4s, v29.4s, v28.4s\n" + ".inst 0x0ea1691c // bfcvtn v28.4h, v8.4s\n" + "zip2 v27.4s, v27.4s, v21.4s\n" + ".inst 0x0ea16a95 // bfcvtn v21.4h, v20.4s\n" + "zip2 v25.4s, v24.4s, v14.4s\n" + ".inst 0x0ea16ad8 // bfcvtn v24.4h, v22.4s\n" + "zip2 v22.4s, v23.4s, v3.4s\n" + ".inst 0x0ea16a6e // bfcvtn v14.4h, v19.4s\n" + "zip2 v20.4s, v12.4s, v13.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v15.4s, v18.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v9.4s, v10.4s\n" + ".inst 0x4ea168cb // bfcvtn2 v11.8h, v6.4s\n" + ".inst 0x4ea16807 // bfcvtn2 v7.8h, v0.4s\n" + ".inst 0x4ea16881 // bfcvtn2 v1.8h, v4.4s\n" + ".inst 0x4ea168a2 // bfcvtn2 v2.8h, v5.4s\n" + ".inst 0x4ea16b5f // bfcvtn2 v31.8h, v26.4s\n" + ".inst 0x4ea16bbe // bfcvtn2 v30.8h, v29.4s\n" + ".inst 0x4ea16b7c // bfcvtn2 v28.8h, v27.4s\n" + ".inst 0x4ea16b35 // bfcvtn2 v21.8h, v25.4s\n" + "str q11, [x27, #0x0]\n" + ".inst 0x4ea16ad8 // bfcvtn2 v24.8h, v22.4s\n" + ".inst 0x4ea16a8e // bfcvtn2 v14.8h, v20.4s\n" + "str q7, [x27, #0x10]\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q1, [x27, #0x20]\n" + "str q2, [x27, #0x30]\n" + "str q31, [x27, #0x40]\n" + "str q30, [x27, #0x50]\n" + "str q28, [x27, #0x60]\n" + "str q21, [x27, #0x70]\n" + "str q24, [x27, #0x80]\n" + "str q14, [x27, #0x90]\n" + "str q19, [x27, #0xa0]\n" + "str q17, [x27, #0xb0]\n" + "add x27, x27, %x[out_stride]\n" + "bge 16b\n" + "17:" // Tail row loop: Column loop skip + "cbz x20, 24f\n" + "cmp x20, #0x10\n" + "movi v16.16b, #0x0\n" + "str q16, [x27, #0x0]\n" + "str q16, [x27, #0x10]\n" + "str q16, [x27, #0x20]\n" + "str q16, [x27, #0x30]\n" + "str q16, [x27, #0x40]\n" + "str q16, [x27, #0x50]\n" + "str q16, [x27, #0x60]\n" + "str q16, 
[x27, #0x70]\n" + "str q16, [x27, #0x80]\n" + "str q16, [x27, #0x90]\n" + "str q16, [x27, #0xa0]\n" + "str q16, [x27, #0xb0]\n" + "blt 19f\n" + "18:" // Tail row loop: width 16 loop: loop + "ldr q20, [x9], #0x10\n" + "ldr q19, [x26], #0x10\n" + "sub x20, x20, #0x10\n" + "ldr q18, [x25], #0x10\n" + "ldr q17, [x24], #0x10\n" + "cmp x20, #0x10\n" + "ldr q0, [x9], #0x10\n" + "ldr q31, [x26], #0x10\n" + "ldr q24, [x25], #0x10\n" + "ldr q16, [x24], #0x10\n" + "ldr q23, [x9], #0x10\n" + "zip1 v30.4s, v20.4s, v18.4s\n" + "zip1 v29.4s, v19.4s, v17.4s\n" + "ldr q22, [x26], #0x10\n" + "ldr q21, [x25], #0x10\n" + "zip2 v28.4s, v20.4s, v18.4s\n" + "zip2 v27.4s, v19.4s, v17.4s\n" + "ldr q20, [x24], #0x10\n" + "ldr q19, [x9], #0x10\n" + "zip1 v26.4s, v0.4s, v24.4s\n" + "zip1 v25.4s, v31.4s, v16.4s\n" + "ldr q18, [x26], #0x10\n" + "ldr q17, [x25], #0x10\n" + "zip2 v8.4s, v0.4s, v24.4s\n" + "zip2 v24.4s, v31.4s, v16.4s\n" + "ldr q16, [x24], #0x10\n" + "zip1 v7.4s, v23.4s, v21.4s\n" + "zip1 v6.4s, v22.4s, v20.4s\n" + "zip2 v5.4s, v23.4s, v21.4s\n" + "zip2 v4.4s, v22.4s, v20.4s\n" + "zip1 v3.4s, v19.4s, v17.4s\n" + "zip1 v2.4s, v18.4s, v16.4s\n" + "zip2 v1.4s, v19.4s, v17.4s\n" + "zip2 v0.4s, v18.4s, v16.4s\n" + "zip1 v23.4s, v30.4s, v29.4s\n" + "zip1 v22.4s, v28.4s, v27.4s\n" + "zip1 v21.4s, v26.4s, v25.4s\n" + "zip1 v20.4s, v8.4s, v24.4s\n" + "zip1 v19.4s, v7.4s, v6.4s\n" + "zip1 v18.4s, v5.4s, v4.4s\n" + "zip1 v17.4s, v3.4s, v2.4s\n" + "zip1 v16.4s, v1.4s, v0.4s\n" + ".inst 0x0ea16aff // bfcvtn v31.4h, v23.4s\n" + "zip2 v30.4s, v30.4s, v29.4s\n" + ".inst 0x0ea16add // bfcvtn v29.4h, v22.4s\n" + "zip2 v28.4s, v28.4s, v27.4s\n" + ".inst 0x0ea16abb // bfcvtn v27.4h, v21.4s\n" + "zip2 v26.4s, v26.4s, v25.4s\n" + ".inst 0x0ea16a99 // bfcvtn v25.4h, v20.4s\n" + "zip2 v24.4s, v8.4s, v24.4s\n" + ".inst 0x0ea16a77 // bfcvtn v23.4h, v19.4s\n" + "zip2 v22.4s, v7.4s, v6.4s\n" + ".inst 0x0ea16a55 // bfcvtn v21.4h, v18.4s\n" + "zip2 v20.4s, v5.4s, v4.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v3.4s, v2.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v1.4s, v0.4s\n" + ".inst 0x4ea16bdf // bfcvtn2 v31.8h, v30.4s\n" + ".inst 0x4ea16b9d // bfcvtn2 v29.8h, v28.4s\n" + ".inst 0x4ea16b5b // bfcvtn2 v27.8h, v26.4s\n" + ".inst 0x4ea16b19 // bfcvtn2 v25.8h, v24.4s\n" + ".inst 0x4ea16ad7 // bfcvtn2 v23.8h, v22.4s\n" + ".inst 0x4ea16a95 // bfcvtn2 v21.8h, v20.4s\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q31, [x27, #0x0]\n" + "str q29, [x27, #0x10]\n" + "str q27, [x27, #0x20]\n" + "str q25, [x27, #0x30]\n" + "str q23, [x27, #0x40]\n" + "str q21, [x27, #0x50]\n" + "str q19, [x27, #0x60]\n" + "str q17, [x27, #0x70]\n" + "add x27, x27, #0x80\n" + "bge 18b\n" + "19:" // Tail row loop: width 16 loop: skip + "cmp x20, #0x4\n" + "blt 21f\n" + "20:" // Tail row loop: width 4 loop: loop + "ldr q21, [x9], #0x10\n" + "ldr q20, [x26], #0x10\n" + "sub x20, x20, #0x4\n" + "ldr q19, [x25], #0x10\n" + "ldr q17, [x24], #0x10\n" + "cmp x20, #0x4\n" + "zip1 v18.4s, v21.4s, v19.4s\n" + "zip1 v16.4s, v20.4s, v17.4s\n" + "zip2 v21.4s, v21.4s, v19.4s\n" + "zip2 v20.4s, v20.4s, v17.4s\n" + "zip1 v17.4s, v18.4s, v16.4s\n" + "zip2 v19.4s, v18.4s, v16.4s\n" + "zip1 v16.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a32 // bfcvtn v18.4h, v17.4s\n" + "zip2 v17.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + ".inst 0x4ea16a72 // bfcvtn2 v18.8h, v19.4s\n" + ".inst 0x4ea16a30 // bfcvtn2 v16.8h, v17.4s\n" + "str q18, [x27, #0x0]\n" + 
"str q16, [x27, #0x10]\n" + "add x27, x27, #0x20\n" + "bge 20b\n" + "21:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 23f\n" + "22:" // Tail row loop: width 1 loop: loop + "ldr s19, [x9], #0x4\n" + "ldr s18, [x26], #0x4\n" + "sub x20, x20, #0x1\n" + "ldr s17, [x25], #0x4\n" + "ldr s16, [x24], #0x4\n" + "cmp x20, #0x1\n" + "zip1 v17.4s, v19.4s, v17.4s\n" + "zip1 v16.4s, v18.4s, v16.4s\n" + "zip1 v16.4s, v17.4s, v16.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "str d16, [x27, #0x0]\n" + "add x27, x27, #0x8\n" + "bge 22b\n" + "23:" // Tail row loop: width 1 loop: skip + "24:" // Tail row loop: odd col skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0xc0\n" + "bge 15b\n" + "25:" // Done + : [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out) + : [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), [width] "r"(width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..1b67fc25a57c7dd4c2e16067137e28dc50ee4dbb --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon.h @@ -0,0 +1,80 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets n step value. +/// +/// The starting row index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon(void); + +/// Gets the offset in bytes to the data element in the RHS matrix buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_offset_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon(size_t n_idx); + +/// Gets the offset in bytes to the data element in the bias buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_bias_offset_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon(size_t n_idx); + +/// Gets the offset in bytes to the data element in the packed RHS buffer. +/// +/// @param[in] n_idx Row index. +/// @param[in] k Number of columns. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon(size_t n_idx, size_t k); + +/// Gets the size in bytes of the packed RHS buffer. +/// +/// @param[in] n Number of rows. +/// @param[in] k Number of columns. +/// +/// @return The size in bytes of the packed RHS buffer. +size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon(size_t n, size_t k); + +/// Runs the RHS packing function for matrix multiplication. +/// +/// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset +/// calculated using the following functions: +/// +/// * RHS: @ref kai_get_rhs_offset_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon. 
+/// * Bias: @ref kai_get_bias_offset_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon. +/// * Output: @ref kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon. +/// +/// @param[in] num_groups Number of groups. It must be 1. +/// @param[in] n Number of columns of the output matrix. +/// @param[in] k Common dimension between the LHS and RHS matrix. +/// @param[in] nr Block size in N dimension. It must be 24. +/// @param[in] kr Block size in K dimension. It must be 4. +/// @param[in] sr Number of kr splits. It must be 1. +/// @param[in] rhs_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs RHS matrix data buffer. +/// @param[in] bias Bias matrix data buffer. +/// @param[in] scale Scale data buffer. It must be NULL. +/// @param[out] rhs_packed Packed RHS matrix. +/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. +/// @param[in] params Extra packing parameters. It must be NULL. +void kai_run_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/test/BUILD.bazel b/test/BUILD.bazel index c806b5d54fdd54ddd92082dfcf502ffbb2c3d29a..be7efde91a2354e4038805b16acc599400c86c04 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -53,8 +53,10 @@ kai_cxx_library( cc_test( name = "kleidiai_test", srcs = [ + "tests/matmul_clamp_f32_bf16p_bf16p_test.cpp", "tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp", "tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp", + "tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp", "tests/matmul_test.cpp", ], copts = kai_cxxopts(kai_cpu_bf16() + kai_cpu_fp16()), diff --git a/test/common/MatMulMethod.hpp b/test/common/MatMulMethod.hpp new file mode 100644 index 0000000000000000000000000000000000000000..542e33b2a384a6c1b653dbe1f8b4df23b6bd9f2f --- /dev/null +++ b/test/common/MatMulMethod.hpp @@ -0,0 +1,356 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include <cstddef> +#include <cstdint> +#include <functional> +#include <string_view> + +#include "kai/kai_common.h" +#include "test/common/data_format.hpp" +#include "test/common/float16.hpp" + +namespace kai::test { + +// NOLINTBEGIN(misc-non-private-member-variables-in-classes) + +/// Matrix multiplication method. +struct MatMulMethod { + std::string_view name; ///< Name of matmul method. + + size_t m0; ///< Block size in M dimension. + size_t n0; ///< Block size in N dimension. + + bool lhs_transposed; ///< LHS matrix is transposed. + bool rhs_transposed; ///< RHS matrix is transposed. + + bool is_sme2; ///< Test is an SME2 test. + + DataFormat dst_format; ///< Data format of the destination matrix. + DataFormat lhs_format; ///< Data format of the LHS matrix. + DataFormat packed_lhs_format; ///< Data format of the packed LHS matrix. + DataFormat rhs_format; ///< Data format of the RHS matrix. + DataFormat packed_rhs_format; ///< Data format of the packed RHS matrix. + DataFormat bias_format; ///< Data format of the bias vector. + + /// Gets mr value. + /// + /// This is the packing parameter which must be used to pack the LHS matrix (if necessary). + /// + /// @return The mr value. + std::function<size_t(void)> fn_get_mr; + + /// Gets nr value.
+ /// + /// This is the packing parameter which must be used to pack the RHS matrix (if necessary). + /// + /// @return The nr value. + std::function<size_t(void)> fn_get_nr; + + /// Gets kr value. + /// + /// This is the packing parameter which must be used to pack the LHS and RHS matrix (if necessary). + /// + /// @return The kr value. + std::function<size_t(void)> fn_get_kr; + + /// Gets sr value. + /// + /// This is the packing parameter which must be used to pack the RHS matrix. + /// + /// @return The sr value. + std::function<size_t(void)> fn_get_sr; + + /// Gets m step value for main kernel. + /// + /// The starting row index must be divisible by `m_step`. + /// + /// @return The m step value. + std::function<size_t(void)> fn_get_main_m_step; + + /// Gets n step value for RHS packing kernel. + /// + /// The starting row index must be divisible by `n_step`. + /// + /// @return The n step value. + std::function<size_t(void)> fn_get_pack_rhs_n_step; + + /// Gets n step value for main kernel. + /// + /// The starting column index must be divisible by `n_step`. + /// + /// @return The n step value. + std::function<size_t(void)> fn_get_main_n_step; + + /// Gets the offset in bytes of the LHS matrix. + /// + /// @param[in] m_idx Coordinate of the matrix in M dimension. + /// @param[in] stride Row stride in bytes. + /// + /// @return The offset in bytes. + std::function<size_t(size_t m_idx, size_t stride)> fn_get_lhs_offset; + + /// Gets the size in bytes of the packed LHS matrix. + /// + /// @param[in] m Number of rows in the unpacked LHS matrix. + /// @param[in] k Number of columns in the unpacked LHS matrix. + /// @param[in] mr Number of rows to be interleaved. + /// @param[in] kr Unused. Must be 1. + /// @param[in] sr Unused. Must be 1. + /// + /// @return The size in bytes. + std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)> fn_get_packed_lhs_size; + + /// Gets the offset in bytes of the packed LHS matrix. + /// + /// @param[in] m_idx Coordinate of the matrix in M dimension. + /// @param[in] k Size of the matrix in K dimension. + /// + /// @return The offset in bytes. + std::function<size_t(size_t m_idx, size_t k)> fn_get_packed_lhs_offset; + + /// Preprocesses the LHS matrix. + /// + /// @param[in] m Number of rows of the unpacked LHS matrix. + /// @param[in] k Common dimension between the LHS and RHS matrix. + /// @param[in] mr Block size in M dimension. It must be the mr value returned by @ref fn_get_mr. + /// @param[in] kr Block size in K dimension. It must be the kr value returned by @ref fn_get_kr. + /// @param[in] sr Number of kr splits. It must be 1. + /// @param[in] m_idx_start Unused. Must be 0. + /// @param[in] lhs LHS matrix data buffer. + /// @param[in] lhs_stride Row stride in bytes of the LHS matrix. + /// @param[out] lhs_packed Packed LHS matrix. + std::function<void( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, + void* lhs_packed)> + fn_pack_lhs; + + /// Gets a value indicating whether LHS packing is needed. + [[nodiscard]] bool is_pack_lhs_needed() const { + return fn_pack_lhs != nullptr; + } + + /// Gets the offset in bytes of the RHS matrix. + /// + /// @param[in] n_idx Coordinate of the matrix in N dimension. + /// + /// @return The offset in bytes. + std::function<size_t(size_t n_idx)> fn_get_rhs_offset; + + /// Gets the size in bytes of the packed RHS matrix. + /// + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// + /// @return The size in bytes. + std::function<size_t(size_t n, size_t k)> fn_get_packed_rhs_size; + + /// Gets the offset in bytes of the packed RHS matrix in the RHS packing kernel. + /// + /// @param[in] n_idx Coordinate of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// + /// @return The offset in bytes.
+ std::function<size_t(size_t n_idx, size_t k)> fn_get_pack_rhs_packed_rhs_offset; + + /// Gets the offset in bytes of the packed RHS matrix in the main kernel. + /// + /// @param[in] n_idx Coordinate of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// + /// @return The offset in bytes. + std::function<size_t(size_t n_idx, size_t k)> fn_get_main_packed_rhs_offset; + + std::function<void( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)> + fn_pack_rhs; + + /// Gets the offset in bytes to the data element in the bias buffer. + /// + /// @param[in] n_idx Column index. + /// + /// @return The offset in bytes to the data element. + std::function<size_t(size_t n_idx)> fn_get_bias_offset; + + /// Gets the offset in bytes to the data element in the destination matrix buffer. + /// + /// @param[in] m_idx Row index. + /// @param[in] n_idx Column index. + /// @param[in] stride Row stride in bytes. + /// + /// @return The offset in bytes to the data element. + std::function<size_t(size_t m_idx, size_t n_idx, size_t stride)> fn_get_dst_offset; + + /// Gets the size in bytes of the destination matrix buffer. + /// + /// @param[in] m Number of rows. + /// @param[in] n Number of columns. + /// + /// @return The size in bytes of the destination matrix buffer. + std::function<size_t(size_t m, size_t n)> fn_get_dst_size; + + /// Performs F16 or F32 matrix multiplication with RHS packing + /// followed by clamp operation. + /// + /// @param[in] m Size of the matrix in M dimension. + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] lhs LHS data buffer. + /// @param[in] packed_rhs Packed RHS data buffer. + /// @param[out] dst Output data buffer. + /// @param[in] lhs_stride LHS row stride. + /// @param[in] dst_stride_row Output row stride. + /// @param[in] dst_stride_col Output column stride. + /// @param[in] clamp_min Lower bound of the output data. + /// @param[in] clamp_max Upper bound of the output data. + std::function<void( + size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* packed_rhs, void* dst, + size_t dst_stride_row, size_t dst_stride_col, Float16 clamp_min, Float16 clamp_max)> + fn_matmul_f16_f16_f16p; + + std::function<void( + size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* packed_rhs, void* dst, + size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max)> + fn_matmul_f32_f32_f32p; + + /// Performs BF16 matrix multiplication with LHS and RHS packing + /// followed by clamp operation. + /// + /// @param[in] m Size of the matrix in M dimension. + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] packed_lhs Packed LHS data buffer. + /// @param[in] packed_rhs Packed RHS data buffer. + /// @param[out] dst Output data buffer. + /// @param[in] dst_stride_row Output row stride. + /// @param[in] dst_stride_col Output column stride. + /// @param[in] clamp_min Lower bound of the output data. + /// @param[in] clamp_max Upper bound of the output data. + std::function<void( + size_t m, size_t n, size_t k, const void* packed_lhs, const void* packed_rhs, void* dst, + size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max)> + fn_matmul_f32_bf16p_bf16p; + + /// Performs BF16 matrix multiplication with RHS packing + /// followed by clamp operation. + /// + /// @param[in] m Size of the matrix in M dimension. + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] lhs LHS data buffer. + /// @param[in] lhs_stride LHS row stride. + /// @param[in] packed_rhs Packed RHS data buffer. + /// @param[out] dst Output data buffer. + /// @param[in] dst_stride_row Output row stride. + /// @param[in] dst_stride_col Output column stride. + /// @param[in] clamp_min Lower bound of the output data. + /// @param[in] clamp_max Upper bound of the output data. + std::function<void( + size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* packed_rhs, void* dst, + size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max)> + fn_matmul_f32_f32_bf16p; + + /// Performs F32 matrix multiplication with LHS & RHS packing + /// followed by clamp operation. + /// + /// @param[in] m Number of output rows to be computed. + /// @param[in] n Number of output columns to be computed.
+ /// @param[in] k Common dimension of the LHS and RHS operands. + /// @param[in] packed_lhs Packed LHS matrix buffer. + /// @param[in] packed_rhs Packed RHS matrix buffer. + /// @param[out] dst Output matrix buffer. + /// @param[in] dst_stride_row Row stride in bytes of the output matrix. + /// @param[in] dst_stride_col Column stride in bytes of the output matrix. + /// @param[in] clamp_min Minimum value to clamp the final result. + /// @param[in] clamp_max Maximum value to clamp the final result. + std::function<void( + size_t m, size_t n, size_t k, const void* packed_lhs, const void* packed_rhs, void* dst, + size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max)> + fn_matmul_f32_f32p_f32p; + + /// Gets a value indicating whether pre-processing the RHS matrix is needed. + [[nodiscard]] bool is_pack_rhs_needed() const { + return fn_pack_rhs != nullptr; + } + + /// Preprocesses the RHS matrix. + /// + /// @param[in] n Size of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. + /// @param[in] rhs RHS data buffer. + /// @param[in] rhs_row_stride RHS row stride. + /// @param[in] bias Bias data buffer. + /// @param[in] scale Quantization scales data buffer. + /// @param[out] packed_rhs Packed RHS data buffer. + void pack_rhs( + size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale, + void* packed_rhs) const { + KAI_UNUSED(n); + KAI_UNUSED(k); + KAI_UNUSED(rhs); + KAI_UNUSED(rhs_row_stride); + KAI_UNUSED(bias); + KAI_UNUSED(scale); + KAI_UNUSED(packed_rhs); + + if (fn_pack_rhs != nullptr) { + fn_pack_rhs( + 1, n, k, fn_get_nr(), fn_get_kr(), fn_get_sr(), rhs_row_stride, rhs, bias, nullptr, packed_rhs, 0, + nullptr); + } else { + KAI_ERROR("RHS pre-processing is not supported!"); + } + } + + [[nodiscard]] bool has_main_kernel() const { + return fn_matmul_f16_f16_f16p != nullptr || fn_matmul_f32_f32p_f32p != nullptr || + fn_matmul_f32_f32_f32p != nullptr || fn_matmul_f32_bf16p_bf16p != nullptr || + fn_matmul_f32_f32_bf16p != nullptr; + } + + void main_kernel( + size_t m, size_t n, size_t k, const void* lhs, const void* rhs, const void* bias, void* dst, size_t lhs_stride, + size_t rhs_stride, size_t dst_stride, float clamp_min, float clamp_max) const { + KAI_UNUSED(bias); + KAI_UNUSED(rhs_stride); + + if (fn_matmul_f16_f16_f16p) { + fn_matmul_f16_f16_f16p( + m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(Float16), clamp_min, + static_cast<Float16>(clamp_max)); + } else if (fn_matmul_f32_f32_f32p) { + fn_matmul_f32_f32_f32p( + m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(float), clamp_min, + static_cast<float>(clamp_max)); + } else if (fn_matmul_f32_f32p_f32p) { + fn_matmul_f32_f32p_f32p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); + } else if (fn_matmul_f32_bf16p_bf16p) { + fn_matmul_f32_bf16p_bf16p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); + } else if (fn_matmul_f32_f32_bf16p) { + fn_matmul_f32_f32_bf16p( + m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); + } else { + KAI_ERROR("Main kernel is not available!"); + } + } +}; + +// NOLINTEND(misc-non-private-member-variables-in-classes) +} // namespace kai::test diff --git a/test/common/bfloat16.hpp b/test/common/bfloat16.hpp index 94c362e395b21e7cff4a8f94dec9092b632f3345..0c26d749058eb8d2eee720638eacd0305228c2be 100644 --- a/test/common/bfloat16.hpp +++ b/test/common/bfloat16.hpp @@ -82,6 +82,14 @@ public: return _data != rhs._data; } + + uint16_t data() const { + return _data; + } + + void set_data(uint16_t data) { + _data = data; + } + /// Writes the value to the output stream. /// /// @param[in] os Output stream to be written to.
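For context, the RHS packing entry points declared earlier are typically driven as in the following minimal caller-side sketch. This snippet is not part of the patch: the buffer names and the row-major f32 RHS layout are assumptions, while the function names, the fixed nr/kr/sr values and the NULL/0 arguments follow the constraints documented in the 4x24 packing header above.

/* Illustrative usage sketch (not part of the patch). Assumes a row-major
 * f32 RHS of shape k x n and an n-element f32 bias; `rhs`, `bias` and
 * `rhs_packed` are hypothetical caller-side names. */
#include <stdlib.h>

#include "kai_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon.h"

void pack_rhs_example(size_t n, size_t k, const float* rhs, const float* bias) {
    /* Query the packed buffer size for the full (n, k) problem. */
    const size_t packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon(n, k);
    void* rhs_packed = malloc(packed_size);

    /* nr == 24, kr == 4, sr == 1, a NULL scale, extra_bytes == 0 and
     * params == NULL are mandatory for this routine, per its header. */
    kai_run_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon(
        1 /* num_groups */, n, k, 24 /* nr */, 4 /* kr */, 1 /* sr */, n * sizeof(float) /* rhs_stride */, rhs,
        bias, NULL /* scale */, rhs_packed, 0 /* extra_bytes */, NULL /* params */);

    /* rhs_packed now holds nr-wide column blocks of bf16 data, each preceded
     * by its f32 bias values, ready for the matching matmul micro-kernel. */
    free(rhs_packed);
}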
diff --git a/test/common/compare.cpp b/test/common/compare.cpp index 54af776fd6937a5dd4deaa3469977a7ec1db3a83..b000f3f5efaa7e06295fd4539e7d05b021806008 100644 --- a/test/common/compare.cpp +++ b/test/common/compare.cpp @@ -213,6 +213,9 @@ bool compare( case DataType::FP16: return compare_raw<Float16>(imp_data, ref_data, format, full_height, full_width, rect, handler); + case DataType::BF16: + return compare_raw<BFloat16>(imp_data, ref_data, format, full_height, full_width, rect, handler); + default: break; } diff --git a/test/common/matmul_test_common.cpp b/test/common/matmul_test_common.cpp new file mode 100644 index 0000000000000000000000000000000000000000..905450fb31f870f6706ca7a394a3b446b7b49f84 --- /dev/null +++ b/test/common/matmul_test_common.cpp @@ -0,0 +1,25 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "matmul_test_common.hpp" + +#include <ostream> +#include <tuple> + +namespace kai::test { +void PrintTo(const MatMulTestParams& param, std::ostream* os) { + const auto& [method, shape, portion] = param; + + // NOLINTBEGIN(cppcoreguidelines-pro-bounds-constant-array-index) + *os << "Method_" << method.name // + << "__M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k // + << "__PortionStartRow_" << static_cast<int>(portion.start_row() * 1000) // + << "__PortionStartCol_" << static_cast<int>(portion.start_col() * 1000) // + << "__PortionHeight_" << static_cast<int>(portion.height() * 1000) // + << "__PortionWidth_" << static_cast<int>(portion.width() * 1000); + // NOLINTEND(cppcoreguidelines-pro-bounds-constant-array-index) +} +} // namespace kai::test diff --git a/test/common/matmul_test_common.hpp b/test/common/matmul_test_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..767117c5d97746b7aad8b9b2392d6b9c0a16fb8c --- /dev/null +++ b/test/common/matmul_test_common.hpp @@ -0,0 +1,26 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <ostream> +#include <tuple> + +#include "test/common/MatMulMethod.hpp" +#include "test/common/matrix_portion.hpp" + +namespace kai::test { +/// Matrix multiplication shape. +struct MatMulShape { + size_t m; ///< LHS height. + size_t n; ///< RHS width. + size_t k; ///< LHS width and RHS height. +}; + +/// Matrix multiplication test information. +using MatMulTestParams = std::tuple<MatMulMethod, MatMulShape, MatrixPortion>; + +/// Prints the test information. +void PrintTo(const MatMulTestParams& param, std::ostream* os); +} // namespace kai::test diff --git a/test/common/memory.hpp b/test/common/memory.hpp index bf5fbb017ed3d4cca5895a895c795c08ff30721f..7ea0aa53249d8e2cec841319517c0bc6e5ad6293 100644 --- a/test/common/memory.hpp +++ b/test/common/memory.hpp @@ -9,6 +9,7 @@ #include <cstddef> #include <cstdint> +#include "test/common/bfloat16.hpp" #include "test/common/int4.hpp" namespace kai::test { @@ -25,6 +26,14 @@ inline constexpr size_t size_in_bits<Int4> = 4; template <> inline constexpr size_t size_in_bits<UInt4> = 4; +/// TODO: Move this +inline float bf16_to_float(uint16_t v) { + const uint32_t lv = (v << 16); + float fp; + memcpy(&fp, &lv, sizeof(lv)); + return fp; +} + /// Reads the array at the specified index. /// /// @param[in] array Data buffer. @@ -39,6 +48,9 @@ T read_array(const void* array, size_t index) { } else if constexpr (std::is_same_v<T, Int4>) { const auto [lo, hi] = Int4::unpack_u8(reinterpret_cast<const uint8_t*>(array)[index / 2]); return index % 2 == 0 ?
lo : hi; + } else if constexpr (std::is_same_v<T, BFloat16>) { + uint16_t raw_value = reinterpret_cast<const uint16_t*>(array)[index]; + return BFloat16(bf16_to_float(raw_value)); } else { return reinterpret_cast<const T*>(array)[index]; } diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp index 221ba36079ec8d6c6bf59289826a37ec2525b300..5c549664ee08894f5e8924e76febeb7b5700052f 100644 --- a/test/reference/pack.cpp +++ b/test/reference/pack.cpp @@ -14,6 +14,7 @@ #include <vector> #include "kai/kai_common.h" +#include "test/common/bfloat16.hpp" #include "test/common/data_format.hpp" #include "test/common/data_type.hpp" #include "test/common/float16.hpp" @@ -25,14 +26,20 @@ namespace kai::test { namespace { +uint16_t convert(const uint8_t* src_ptr_elm, DataType src_dtype, DataType dst_dtype) { + KAI_ASSUME(src_dtype == DataType::FP32 && dst_dtype == DataType::BF16); + return BFloat16(*reinterpret_cast<const float*>(src_ptr_elm)).data(); +} + std::vector<uint8_t> pack_block( - const void* src, size_t data_esize, size_t full_height, size_t full_width, size_t block_height, size_t block_width, - size_t subblock_height, size_t subblock_width) { + const void* src, DataType src_dtype, DataType dst_dtype, size_t src_esize, size_t dst_esize, size_t full_height, + size_t full_width, size_t block_height, size_t block_width, size_t subblock_height, size_t subblock_width) { const auto dst_bytes = - round_up_multiple(full_height, block_height) * round_up_multiple(full_width, block_width) * data_esize; + round_up_multiple(full_height, block_height) * round_up_multiple(full_width, block_width) * dst_esize; std::vector<uint8_t> dst; dst.resize(dst_bytes); + memset(dst.data(), 0, dst_bytes); const auto* src_ptr = reinterpret_cast<const uint8_t*>(src); auto* dst_ptr = dst.data(); @@ -42,18 +49,38 @@ std::vector<uint8_t> pack_block( for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) { for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) { for (size_t y_element = 0; y_element < subblock_height; ++y_element) { - if (y_block + y_subblock + y_element < full_height) { - const auto len = std::min(subblock_width, full_width - x_block - x_subblock); - - memcpy( - dst_ptr, - src_ptr + - ((y_block + y_subblock + y_element) * full_width + x_block + x_subblock) * - data_esize, - len * data_esize); - } + if (src_dtype == dst_dtype) { + const size_t esize = dst_esize; + + if (y_block + y_subblock + y_element < full_height) { + const auto len = std::min(subblock_width, full_width - x_block - x_subblock); + + memcpy( + dst_ptr, + src_ptr + + ((y_block + y_subblock + y_element) * full_width + x_block + x_subblock) * + esize, + len * esize); + } - dst_ptr += subblock_width * data_esize; + dst_ptr += subblock_width * esize; + } else if (dst_esize == 2 /* 16 bits */) { + for (size_t x_element = 0; x_element < subblock_width; ++x_element) { + if (y_block + y_subblock + y_element < full_height) { + if (x_block + x_subblock + x_element < full_width) { + const uint8_t* src_ptr_elm = src_ptr + + ((y_block + y_subblock + y_element) * full_width + x_block + x_subblock + + x_element) * + src_esize; + + uint16_t src_value = convert(src_ptr_elm, src_dtype, dst_dtype); + memcpy(dst_ptr, &src_value, dst_esize); + } + } + + dst_ptr += dst_esize; + } + } } } } @@ -67,43 +94,67 @@ std::vector<uint8_t> pack_block( /// Packs the matrix from raw to per-row bias format.
std::vector<uint8_t> pack_bias_per_row( - size_t data_esize, size_t zero_point_esize, const void* src, const void* bias, size_t height, size_t width, - size_t block_height, size_t block_width, size_t subblock_height, size_t subblock_width) { + DataType src_dtype, DataType bias_dtype, DataType dst_dtype, size_t src_esize, size_t bias_esize, size_t dst_esize, + const void* src, const void* bias, size_t height, size_t width, size_t block_height, size_t block_width, + size_t subblock_height, size_t subblock_width) { + KAI_ASSUME(src_dtype == bias_dtype); + const auto num_groups = (height + block_height - 1) / block_height; const auto group_num_blocks = (width + block_width - 1) / block_width; - - const auto group_zero_points_bytes = block_height * zero_point_esize; - const auto block_data_bytes = block_height * block_width * data_esize; - const auto group_bytes = group_zero_points_bytes + group_num_blocks * block_data_bytes; + const auto group_bias_bytes = block_height * bias_esize; + const auto block_data_bytes = block_height * block_width * dst_esize; + const auto group_bytes = group_bias_bytes + group_num_blocks * block_data_bytes; const auto dst_bytes = num_groups * group_bytes; std::vector<uint8_t> dst; dst.resize(dst_bytes); + memset(dst.data(), 0, dst_bytes); const auto* src_ptr = reinterpret_cast<const uint8_t*>(src); const auto* bias_ptr = reinterpret_cast<const uint8_t*>(bias); auto* dst_ptr = dst.data(); for (size_t y_block = 0; y_block < height; y_block += block_height) { - // Packs the zero points. + // Packs the bias. const auto bias_len = std::min(block_height, height - y_block); - memcpy(dst_ptr, bias_ptr, bias_len * zero_point_esize); - bias_ptr += block_height * zero_point_esize; - dst_ptr += block_height * zero_point_esize; + memcpy(dst_ptr, bias_ptr, bias_len * bias_esize); + bias_ptr += block_height * bias_esize; + dst_ptr += block_height * bias_esize; for (size_t x_block = 0; x_block < width; x_block += block_width) { for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) { for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) { for (size_t y_element = 0; y_element < subblock_height; ++y_element) { - if (y_block + y_subblock + y_element < height) { - const auto len = std::min(subblock_width, width - x_block - x_subblock); - memcpy( - dst_ptr, - src_ptr + - ((y_block + y_subblock + y_element) * width + x_block + x_subblock) * data_esize, - len * data_esize); + if (src_dtype == dst_dtype) { + const size_t esize = dst_esize; + if (y_block + y_subblock + y_element < height) { + const auto len = std::min(subblock_width, width - x_block - x_subblock); + + memcpy( + dst_ptr, + src_ptr + + ((y_block + y_subblock + y_element) * width + x_block + x_subblock) * esize, + len * esize); + } + + dst_ptr += subblock_width * esize; + } else if (dst_esize == 2 /* 16 bits */) { + for (size_t x_element = 0; x_element < subblock_width; ++x_element) { + if (y_block + y_subblock + y_element < height) { + if (x_block + x_subblock + x_element < width) { + const uint8_t* src_ptr_elm = src_ptr + + ((y_block + y_subblock + y_element) * width + x_block + x_subblock + + x_element) * + src_esize; + + uint16_t src_value = convert(src_ptr_elm, src_dtype, dst_dtype); + memcpy(dst_ptr, &src_value, dst_esize); + } + } + + dst_ptr += dst_esize; + } } - dst_ptr += subblock_width * data_esize; } } } @@ -118,7 +169,7 @@ std::vector<uint8_t> pack_bias_per_row( } // namespace std::vector<uint8_t> pack( - const DataFormat& dst_format, const void* src, [[maybe_unused]] const void* scales, const void* zero_points, + const 
DataFormat& dst_format, const void* src, [[maybe_unused]] const void* scales, const void* bias, const DataFormat& src_format, size_t height, size_t width) { const auto dst_dt = dst_format.data_type(); const auto dst_qf = dst_format.pack_format(); @@ -131,27 +182,31 @@ std::vector pack( const auto subblock_width = dst_format.actual_subblock_width(width); if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::BIAS_PER_ROW) { - KAI_ASSUME(src_dt == dst_dt); + KAI_ASSUME((src_dt == dst_dt) || (src_dt == DataType::FP32 && dst_dt == DataType::BF16)); - const auto data_esize = data_type_size_in_bits(dst_dt); - const auto zero_point_esize = data_type_size_in_bits(dst_format.zero_point_data_type()); + const auto src_esize = data_type_size_in_bits(src_dt); + const auto dst_esize = data_type_size_in_bits(dst_dt); + const auto bias_esize = data_type_size_in_bits(dst_format.zero_point_data_type()); + const auto bias_dt = dst_format.zero_point_data_type(); - if (data_esize % 8 == 0 && zero_point_esize % 8 == 0) { - return pack_bias_per_row( - data_esize / 8, zero_point_esize / 8, src, zero_points, height, width, block_height, block_width, - subblock_height, subblock_width); - } + KAI_ASSUME(dst_esize % 8 == 0 && bias_esize % 8 == 0 && src_esize % 8 == 0); + + return pack_bias_per_row( + src_dt, bias_dt, dst_dt, src_esize / 8, bias_esize / 8, dst_esize / 8, src, bias, height, width, + block_height, block_width, subblock_height, subblock_width); } if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::NONE) { - KAI_ASSUME(src_dt == dst_dt); + KAI_ASSUME((src_dt == dst_dt) || (src_dt == DataType::FP32 && dst_dt == DataType::BF16)); - const auto data_esize = data_type_size_in_bits(dst_dt); + const auto dst_esize = data_type_size_in_bits(dst_dt); + const auto src_esize = data_type_size_in_bits(src_dt); - if (data_esize % 8 == 0) { - return pack_block( - src, data_esize / 8, height, width, block_height, block_width, subblock_height, subblock_width); - } + KAI_ASSUME(src_esize % 8 == 0 && dst_esize % 8 == 0); + + return pack_block( + src, src_dt, dst_dt, src_esize / 8, dst_esize / 8, height, width, block_height, block_width, + subblock_height, subblock_width); } KAI_ERROR("Unsupported operation!"); diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..493af592ecec1b9778976cf2d4b153a87dd0dc3f --- /dev/null +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -0,0 +1,328 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "kai/kai_common.h" +#include "test/common/MatMulMethod.hpp" +#include "test/common/compare.hpp" +#include "test/common/cpu_info.hpp" +#include "test/common/data_format.hpp" +#include "test/common/data_type.hpp" +#include "test/common/float16.hpp" +#include "test/common/matmul_test_common.hpp" +#include "test/common/matrix_portion.hpp" +#include "test/common/printer.hpp" +#include "test/common/sme.hpp" +#include "test/reference/fill.hpp" +#include "test/reference/matmul.hpp" +#include "test/reference/pack.hpp" + +// matmul_clamp_f32_bf16p_bf16p +#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla.h" +#include 
"kai/ukernels/matmul/pack/kai_lhs_pack_f32p8x4_bf16_neon.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon.h" +namespace kai::test { + +/// List of supported matrix multiplication methods. +static const std::array matmul_methods = { + MatMulMethod{ + .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla", + + .m0 = 8, + .n0 = 12, + + .lhs_transposed = false, + .rhs_transposed = false, + + .is_sme2 = false, + + .dst_format = DataFormat(DataType::FP32), + .lhs_format = DataFormat(DataType::FP32), + .packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4), + .rhs_format = DataFormat(DataType::FP32), + .packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4), + .bias_format = DataFormat(DataType::FP32), + + .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + + .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p8x4_bf16_neon, + .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p8x4_bf16_neon, + .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_pack_f32p8x4_bf16_neon, + .fn_pack_lhs = kai_run_lhs_pack_f32p8x4_bf16_neon, + + .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + .fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_pack_rhs = kai_run_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + + .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p4x12biasf32_f32_bf16_neon, + + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + + .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla, + }, +}; + +/// Matrix multiplication test fixture. +class MatMulTestBf16 : public testing::TestWithParam { +private: + /// Unique ID: m, n, k + using TestDataId = std::tuple; + +protected: + /// Cached test data that is shared between multiple test case. + struct TestData { + std::vector lhs{}; ///< LHS operand. + std::vector ref_packed_lhs{}; ///< Reference packed LHS. + std::vector rhs{}; ///< RHS operand. + std::vector rhs_scales{}; ///< RHS per-row quantization scales. + std::vector bias{}; ///< Bias. + std::vector ref_packed_rhs{}; ///< Reference packed RHS. + std::vector ref_dst{}; ///< Reference output. + }; + + /// Gets the test data for the current test case. 
+    static const TestData& test_data() {
+        const auto& [method, info, portion] = GetParam();
+        const TestDataId data_id{info.m, info.n, info.k, method.name};
+
+        // If the test data is already available, returns it.
+        const auto data_it = _data.find(data_id);
+
+        if (data_it != _data.end()) {
+            return data_it->second;
+        }
+
+        // Generates the test data.
+        const auto has_lhs_pack = method.packed_lhs_format.data_type() != DataType::UNKNOWN;
+        const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN;
+        const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN;
+
+        const auto lhs_h = method.lhs_transposed ? info.k : info.m;
+        const auto lhs_w = method.lhs_transposed ? info.m : info.k;
+        auto lhs = fill_matrix_random(lhs_h, lhs_w, method.lhs_format, 0);
+        std::vector<uint8_t> ref_packed_lhs;
+
+        if (has_lhs_pack) {
+            ref_packed_lhs =
+                pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w);
+        }
+
+        const auto rhs_h = method.rhs_transposed ? info.n : info.k;
+        const auto rhs_w = method.rhs_transposed ? info.k : info.n;
+        auto rhs = fill_matrix_random(rhs_h, rhs_w, method.rhs_format, 1);
+
+        std::vector<uint8_t> rhs_scales;
+        if (data_type_is_quantized(method.rhs_format.data_type()) &&
+            method.rhs_format.pack_format() == DataFormat::PackFormat::NONE) {
+            rhs_scales = fill_matrix_random(rhs_h, 1, DataFormat(DataType::FP32), 2);
+        }
+
+        const auto bias_h = 1;
+        const auto bias_w = info.n;
+        std::vector<uint8_t> bias;
+
+        if (has_bias) {
+            bias = fill_matrix_random(bias_h, bias_w, method.bias_format, 3);
+        }
+
+        std::vector<uint8_t> packed_rhs;
+        if (has_rhs_pack) {
+            packed_rhs = matmul_pack_rhs(
+                rhs.data(), !rhs_scales.empty() ? rhs_scales.data() : nullptr, bias.data(), method.rhs_format,
+                method.packed_rhs_format, info.n, info.k, !method.rhs_transposed);
+        }
+
+        KAI_ASSUME(method.lhs_format.is_raw());
+        KAI_ASSUME(method.rhs_format.is_raw());
+        KAI_ASSUME(method.dst_format.is_raw());
+
+        auto ref_dst = matmul(
+            lhs.data(), nullptr, nullptr, method.lhs_format.data_type(),           //
+            rhs.data(), rhs_scales.data(), nullptr, method.rhs_format.data_type(), //
+            bias.data(), nullptr, nullptr, method.bias_format.data_type(),         //
+            method.dst_format.data_type(),                                         //
+            info.m, info.n, info.k, method.lhs_transposed, method.rhs_transposed);
+
+        const auto& data = _data[data_id] = {
+            .lhs = std::move(lhs),
+            .ref_packed_lhs = std::move(ref_packed_lhs),
+            .rhs = std::move(rhs),
+            .rhs_scales = std::move(rhs_scales),
+            .bias = std::move(bias),
+            .ref_packed_rhs = std::move(packed_rhs),
+            .ref_dst = std::move(ref_dst),
+        };
+
+        return data;
+    }
+
+private:
+    // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
+    static std::map<TestDataId, TestData> _data;
+    // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
+};
+
+// NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
+std::map<MatMulTestBf16::TestDataId, MatMulTestBf16::TestData> MatMulTestBf16::_data;
+// NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
+
+/// Tests the output.
+TEST_P(MatMulTestBf16, Output) {
+    const auto& [method, info, portion] = GetParam();
+    const auto& data = test_data();
+
+    if (method.is_sme2 && !cpu_has_sme2()) {
+        GTEST_SKIP();
+    }
+
+    if (!method.has_main_kernel()) {
+        GTEST_SKIP();
+    }
+
+    const auto m_step = method.fn_get_main_m_step();
+    ASSERT_EQ(m_step, method.m0);
+
+    const auto n_step = method.fn_get_main_n_step();
+    ASSERT_EQ(n_step, method.n0);
+
+    const auto rect = portion.compute_portion(info.m, info.n, method.m0, method.n0);
+
+    if (rect.height() == 0 || rect.width() == 0) {
+        GTEST_SKIP();
+    }
+
+    const auto lhs_w = method.lhs_transposed ? info.m : info.k;
+    const auto rhs_w = method.rhs_transposed ? info.k : info.n;
+    const auto bias_w = info.n;
+    const auto dst_w = info.n;
+
+    const auto lhs_start_row = method.lhs_transposed ? 0 : rect.start_row();
+    const auto lhs_start_col = method.lhs_transposed ? rect.start_row() : 0;
+    const auto lhs_stride = method.lhs_format.default_row_stride(lhs_w);
+
+    const uint8_t* lhs_data = nullptr;
+    uintptr_t lhs_offset = 0;
+
+    if (method.is_pack_lhs_needed()) {
+        lhs_data = data.ref_packed_lhs.data();
+
+        const auto ref_lhs_offset =
+            method.packed_lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, info.k);
+        KAI_UNUSED(ref_lhs_offset);
+
+        lhs_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k);
+
+        // TODO: Check with ref_lhs_offset after fixing default_offset_in_bytes()
+    } else {
+        lhs_data = data.lhs.data();
+
+        lhs_offset = method.fn_get_lhs_offset(lhs_start_row, lhs_stride);
+        const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, lhs_w);
+        ASSERT_EQ(lhs_offset, ref_lhs_offset);
+    }
+
+    const auto rhs_stride = method.rhs_format.default_row_stride(rhs_w);
+
+    const uint8_t* rhs_data = nullptr;
+    uintptr_t rhs_offset = 0;
+
+    if (method.is_pack_rhs_needed()) {
+        const auto packed_rhs_start_row = rect.start_col();
+        const auto packed_rhs_start_col = 0;
+
+        rhs_data = data.ref_packed_rhs.data();
+
+        rhs_offset = method.fn_get_main_packed_rhs_offset(packed_rhs_start_row, info.k);
+        const auto ref_rhs_offset =
+            method.packed_rhs_format.default_offset_in_bytes(packed_rhs_start_row, packed_rhs_start_col, info.k);
+
+        ASSERT_EQ(rhs_offset, ref_rhs_offset);
+    } else {
+        const auto rhs_start_row = method.rhs_transposed ? rect.start_col() : 0;
+        const auto rhs_start_col = method.rhs_transposed ? 0 : rect.start_col();
+
+        rhs_data = data.rhs.data();
+        rhs_offset = method.rhs_format.default_offset_in_bytes(rhs_start_row, rhs_start_col, rhs_w);
+    }
+
+    const auto* bias_data = data.bias.data();
+    const auto bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_col(), bias_w);
+
+    const auto dst_stride = method.dst_format.default_row_stride(dst_w);
+    const auto dst_offset = method.fn_get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
+    const auto ref_dst_offset = method.dst_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), dst_w);
+    ASSERT_EQ(dst_offset, ref_dst_offset);
+
+    const auto dst_size = method.fn_get_dst_size(info.m, info.n);
+    const auto ref_dst_size = method.dst_format.default_size_in_bytes(info.m, info.n);
+    ASSERT_EQ(dst_size, ref_dst_size);
+
+    std::vector<uint8_t> dst;
+    dst.resize(dst_size);
+
+    method.main_kernel(
+        rect.height(), rect.width(), info.k, lhs_data + lhs_offset, rhs_data + rhs_offset, bias_data + bias_offset,
+        dst.data() + dst_offset, lhs_stride, rhs_stride, dst_stride, -std::numeric_limits<float>::infinity(),
+        std::numeric_limits<float>::infinity());
+
+    DefaultMismatchHandler handler(0, 0.1, 0, 0.05);
+    const auto success = compare(dst.data(), data.ref_dst.data(), method.dst_format, info.m, info.n, rect, handler);
+
+    ASSERT_TRUE(success);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    MatMul, MatMulTestBf16,
+    testing::Combine(
+        testing::ValuesIn(matmul_methods),
+        testing::Values(
+            MatMulShape{3, 7, 3},        // Smaller than block size
+            MatMulShape{12, 8, 4},       // Same block size
+            MatMulShape{1, 1, 1023},     // Long K
+            MatMulShape{1013, 1, 5},     // Long M
+            MatMulShape{2, 1013, 6},     // Long N
+            MatMulShape{13, 33, 23},     //
+            MatMulShape{93, 57, 89},     //
+            MatMulShape{256, 256, 256},  // Nice shapes
+            MatMulShape{257, 113, 373}   // Prime numbers
+        ),
+        testing::Values(
+            MatrixPortion(0, 0, 1, 1),         // Full matrix.
+            MatrixPortion(0, 0, 0.25, 0.25),   // Top-left corner.
+            MatrixPortion(0.75, 0.75, 1, 1),   // Bottom-right corner.
+            MatrixPortion(0.75, 0, 1, 1),      // Partial rows.
+            MatrixPortion(0.4, 0.5, 0.6, 0.8)  // Somewhere in the middle.
+        )),
+    testing::PrintToStringParamName());
+
+} // namespace kai::test
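
For orientation, the wrapper-level flow that the test above exercises per portion can be condensed into one whole-matrix run. This is a sketch built only from the MatMulMethod members used in this patch; it assumes the same includes as the test file plus <limits> and <vector>, and that the caller supplies the raw FP32 buffers, strides, and shape. It is not part of the patch itself:

    // Sketch only: one full matmul_clamp_f32_bf16p_bf16p run driven through the
    // MatMulMethod wrappers above; mirrors the Output test's call sequence.
    void run_bf16p_bf16p(
        const MatMulMethod& method, size_t m, size_t n, size_t k,  //
        const void* lhs, size_t lhs_stride,                        //
        const void* rhs, size_t rhs_stride,                        //
        const void* bias, void* dst, size_t dst_stride) {
        // Pack the FP32 LHS into the BF16 8x4 blocked layout expected by BFMMLA.
        std::vector<uint8_t> packed_lhs(
            method.fn_get_packed_lhs_size(m, k, method.fn_get_mr(), method.fn_get_kr(), method.fn_get_sr()));
        method.fn_pack_lhs(
            m, k, method.fn_get_mr(), method.fn_get_kr(), method.fn_get_sr(), 0, lhs, lhs_stride, packed_lhs.data());

        // Pack the FP32 RHS and the bias into the BF16 12x4 blocked layout.
        std::vector<uint8_t> packed_rhs(method.fn_get_packed_rhs_size(n, k));
        method.pack_rhs(n, k, rhs, rhs_stride, bias, /*scale=*/nullptr, packed_rhs.data());

        // Run the main kernel over the whole output with clamping effectively disabled.
        method.main_kernel(
            m, n, k, packed_lhs.data(), packed_rhs.data(), bias, dst, lhs_stride, rhs_stride, dst_stride,
            -std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity());
    }
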
diff --git a/test/tests/matmul_clamp_f32_f32_bf16p_test.cpp b/test/tests/matmul_clamp_f32_f32_bf16p_test.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..bb1047c666adae50569d9e7896f6324799691cbd
--- /dev/null
+++ b/test/tests/matmul_clamp_f32_f32_bf16p_test.cpp
@@ -0,0 +1,325 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+#include <stddef.h>
+
+#include <array>
+#include <cstdint>
+#include <cstdlib>
+#include <functional>
+#include <limits>
+#include <map>
+#include <ostream>
+#include <string_view>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "kai/kai_common.h"
+#include "test/common/MatMulMethod.hpp"
+#include "test/common/compare.hpp"
+#include "test/common/cpu_info.hpp"
+#include "test/common/data_format.hpp"
+#include "test/common/data_type.hpp"
+#include "test/common/float16.hpp"
+#include "test/common/matmul_test_common.hpp"
+#include "test/common/matrix_portion.hpp"
+#include "test/common/printer.hpp"
+#include "test/common/sme.hpp"
+#include "test/reference/fill.hpp"
+#include "test/reference/matmul.hpp"
+#include "test/reference/pack.hpp"
+
+// matmul_clamp_f32_f32_bf16p
+#include "kai/ukernels/matmul/matmul_clamp_f32_f32_bf16p/kai_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla.h"
+#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon.h"
+
+namespace kai::test {
+
+/// List of supported matrix multiplication methods.
+static const std::array<MatMulMethod, 1> matmul_methods = {
+    MatMulMethod{
+        .name = "matmul_nt_nt_f32_f32_bf16p_4x24_neon_mmla",
+
+        .m0 = 4,
+        .n0 = 24,
+
+        .lhs_transposed = false,
+        .rhs_transposed = false,
+
+        .is_sme2 = false,
+
+        .dst_format = DataFormat(DataType::FP32),
+        .lhs_format = DataFormat(DataType::FP32),
+        .packed_lhs_format = DataFormat(DataType::UNKNOWN),
+        .rhs_format = DataFormat(DataType::FP32),
+        .packed_rhs_format = DataFormat(
+            DataType::BF16, 24, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 24, 4),
+        .bias_format = DataFormat(DataType::FP32),
+
+        .fn_get_mr = kai_get_mr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+        .fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+        .fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+        .fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+
+        .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+        .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon,
+        .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+
+        .fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+        .fn_get_packed_lhs_size = nullptr,
+        .fn_get_packed_lhs_offset = nullptr,
+        .fn_pack_lhs = nullptr,
+
+        .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon,
+        .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon,
+        .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon,
+        .fn_get_main_packed_rhs_offset =
+            kai_get_rhs_packed_offset_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+        .fn_pack_rhs = kai_run_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon,
+
+        .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p4x24biasf32_f32_bf16_neon,
+
+        .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+        .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+
+        .fn_matmul_f32_f32_bf16p = kai_run_matmul_clamp_f32_f32_bf16p4x1biasf32_4x24x4_neon_mmla,
+    },
+};
+
+/// Matrix multiplication test fixture.
+class MatMulTestHybridBf16 : public testing::TestWithParam<MatMulTestParams> {
+private:
+    /// Unique ID: m, n, k, method name.
+    using TestDataId = std::tuple<size_t, size_t, size_t, std::string_view>;
+
+protected:
+    /// Cached test data that is shared between multiple test cases.
+    struct TestData {
+        std::vector<uint8_t> lhs{};             ///< LHS operand.
+        std::vector<uint8_t> ref_packed_lhs{};  ///< Reference packed LHS.
+        std::vector<uint8_t> rhs{};             ///< RHS operand.
+        std::vector<uint8_t> rhs_scales{};      ///< RHS per-row quantization scales.
+        std::vector<uint8_t> bias{};            ///< Bias.
+        std::vector<uint8_t> ref_packed_rhs{};  ///< Reference packed RHS.
+        std::vector<uint8_t> ref_dst{};         ///< Reference output.
+    };
+
+    /// Gets the test data for the current test case.
+    static const TestData& test_data() {
+        const auto& [method, info, portion] = GetParam();
+        const TestDataId data_id{info.m, info.n, info.k, method.name};
+
+        // If the test data is already available, returns it.
+        const auto data_it = _data.find(data_id);
+
+        if (data_it != _data.end()) {
+            return data_it->second;
+        }
+
+        // Generates the test data.
+        const auto has_lhs_pack = method.packed_lhs_format.data_type() != DataType::UNKNOWN;
+        const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN;
+        const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN;
+
+        const auto lhs_h = method.lhs_transposed ? info.k : info.m;
+        const auto lhs_w = method.lhs_transposed ? info.m : info.k;
+        auto lhs = fill_matrix_random(lhs_h, lhs_w, method.lhs_format, 0);
+        std::vector<uint8_t> ref_packed_lhs;
+
+        if (has_lhs_pack) {
+            ref_packed_lhs =
+                pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w);
+        }
+
+        const auto rhs_h = method.rhs_transposed ? info.n : info.k;
+        const auto rhs_w = method.rhs_transposed ? info.k : info.n;
+        auto rhs = fill_matrix_random(rhs_h, rhs_w, method.rhs_format, 1);
+
+        std::vector<uint8_t> rhs_scales;
+        if (data_type_is_quantized(method.rhs_format.data_type()) &&
+            method.rhs_format.pack_format() == DataFormat::PackFormat::NONE) {
+            rhs_scales = fill_matrix_random(rhs_h, 1, DataFormat(DataType::FP32), 2);
+        }
+
+        const auto bias_h = 1;
+        const auto bias_w = info.n;
+        std::vector<uint8_t> bias;
+
+        if (has_bias) {
+            bias = fill_matrix_random(bias_h, bias_w, method.bias_format, 3);
+        }
+
+        std::vector<uint8_t> packed_rhs;
+        if (has_rhs_pack) {
+            packed_rhs = matmul_pack_rhs(
+                rhs.data(), !rhs_scales.empty() ? rhs_scales.data() : nullptr, bias.data(), method.rhs_format,
+                method.packed_rhs_format, info.n, info.k, !method.rhs_transposed);
+        }
+
+        KAI_ASSUME(method.lhs_format.is_raw());
+        KAI_ASSUME(method.rhs_format.is_raw());
+        KAI_ASSUME(method.dst_format.is_raw());
+
+        auto ref_dst = matmul(
+            lhs.data(), nullptr, nullptr, method.lhs_format.data_type(),           //
+            rhs.data(), rhs_scales.data(), nullptr, method.rhs_format.data_type(), //
+            bias.data(), nullptr, nullptr, method.bias_format.data_type(),         //
+            method.dst_format.data_type(),                                         //
+            info.m, info.n, info.k, method.lhs_transposed, method.rhs_transposed);
+
+        const auto& data = _data[data_id] = {
+            .lhs = std::move(lhs),
+            .ref_packed_lhs = std::move(ref_packed_lhs),
+            .rhs = std::move(rhs),
+            .rhs_scales = std::move(rhs_scales),
+            .bias = std::move(bias),
+            .ref_packed_rhs = std::move(packed_rhs),
+            .ref_dst = std::move(ref_dst),
+        };
+
+        return data;
+    }
+
+private:
+    // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
+    static std::map<TestDataId, TestData> _data;
+    // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
+};
+
+// NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
+std::map<MatMulTestHybridBf16::TestDataId, MatMulTestHybridBf16::TestData> MatMulTestHybridBf16::_data;
+// NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
+
+/// Tests the output.
+TEST_P(MatMulTestHybridBf16, Output) {
+    const auto& [method, info, portion] = GetParam();
+    const auto& data = test_data();
+
+    if (method.is_sme2 && !cpu_has_sme2()) {
+        GTEST_SKIP();
+    }
+
+    ASSERT_TRUE(method.has_main_kernel());
+
+    const auto m_step = method.fn_get_main_m_step();
+    ASSERT_EQ(m_step, method.m0);
+
+    const auto n_step = method.fn_get_main_n_step();
+    ASSERT_EQ(n_step, method.n0);
+
+    const auto rect = portion.compute_portion(info.m, info.n, method.m0, method.n0);
+
+    if (rect.height() == 0 || rect.width() == 0) {
+        GTEST_SKIP();
+    }
+
+    const auto lhs_w = method.lhs_transposed ? info.m : info.k;
+    const auto rhs_w = method.rhs_transposed ? info.k : info.n;
+    const auto bias_w = info.n;
+    const auto dst_w = info.n;
+
+    const auto lhs_start_row = method.lhs_transposed ? 0 : rect.start_row();
+    const auto lhs_start_col = method.lhs_transposed ? rect.start_row() : 0;
+    const auto lhs_stride = method.lhs_format.default_row_stride(lhs_w);
+
+    const uint8_t* lhs_data = nullptr;
+    uintptr_t lhs_offset = 0;
+
+    if (method.is_pack_lhs_needed()) {
+        lhs_data = data.ref_packed_lhs.data();
+
+        const auto ref_lhs_offset =
+            method.packed_lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, info.k);
+        KAI_UNUSED(ref_lhs_offset);
+
+        lhs_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k);
+
+        // TODO: Check with ref_lhs_offset after fixing default_offset_in_bytes()
+    } else {
+        lhs_data = data.lhs.data();
+
+        lhs_offset = method.fn_get_lhs_offset(lhs_start_row, lhs_stride);
+        const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, lhs_w);
+        ASSERT_EQ(lhs_offset, ref_lhs_offset);
+    }
+
+    const auto rhs_stride = method.rhs_format.default_row_stride(rhs_w);
+
+    const uint8_t* rhs_data = nullptr;
+    uintptr_t rhs_offset = 0;
+
+    if (method.is_pack_rhs_needed()) {
+        const auto packed_rhs_start_row = rect.start_col();
+        const auto packed_rhs_start_col = 0;
+
+        rhs_data = data.ref_packed_rhs.data();
+
+        rhs_offset = method.fn_get_main_packed_rhs_offset(packed_rhs_start_row, info.k);
+        const auto ref_rhs_offset =
+            method.packed_rhs_format.default_offset_in_bytes(packed_rhs_start_row, packed_rhs_start_col, info.k);
+
+        ASSERT_EQ(rhs_offset, ref_rhs_offset);
+    } else {
+        const auto rhs_start_row = method.rhs_transposed ? rect.start_col() : 0;
+        const auto rhs_start_col = method.rhs_transposed ? 0 : rect.start_col();
+
+        rhs_data = data.rhs.data();
+        rhs_offset = method.rhs_format.default_offset_in_bytes(rhs_start_row, rhs_start_col, rhs_w);
+    }
+
+    const auto* bias_data = data.bias.data();
+    const auto bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_col(), bias_w);
+
+    const auto dst_stride = method.dst_format.default_row_stride(dst_w);
+    const auto dst_offset = method.fn_get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
+    const auto ref_dst_offset = method.dst_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), dst_w);
+    ASSERT_EQ(dst_offset, ref_dst_offset);
+
+    const auto dst_size = method.fn_get_dst_size(info.m, info.n);
+    const auto ref_dst_size = method.dst_format.default_size_in_bytes(info.m, info.n);
+    ASSERT_EQ(dst_size, ref_dst_size);
+
+    std::vector<uint8_t> dst;
+    dst.resize(dst_size);
+
+    method.main_kernel(
+        rect.height(), rect.width(), info.k, lhs_data + lhs_offset, rhs_data + rhs_offset, bias_data + bias_offset,
+        dst.data() + dst_offset, lhs_stride, rhs_stride, dst_stride, -std::numeric_limits<float>::infinity(),
+        std::numeric_limits<float>::infinity());
+
+    DefaultMismatchHandler handler(0, 0.1, 0, 0.05);
+    const auto success = compare(dst.data(), data.ref_dst.data(), method.dst_format, info.m, info.n, rect, handler);
+
+    ASSERT_TRUE(success);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    MatMul, MatMulTestHybridBf16,
+    testing::Combine(
+        testing::ValuesIn(matmul_methods),
+        testing::Values(
+            MatMulShape{3, 7, 3},        // Smaller than block size
+            MatMulShape{4, 24, 4},       // Same block size
+            MatMulShape{1, 1, 1023},     // Long K
+            MatMulShape{1013, 1, 5},     // Long M
+            MatMulShape{2, 1013, 6},     // Long N
+            MatMulShape{13, 33, 23},     //
+            MatMulShape{93, 57, 89},     //
+            MatMulShape{256, 256, 256},  // Nice shapes
+            MatMulShape{257, 113, 373}   // Prime numbers
+        ),
+        testing::Values(
+            MatrixPortion(0, 0, 1, 1),         // Full matrix.
+            MatrixPortion(0, 0, 0.25, 0.25),   // Top-left corner.
+            MatrixPortion(0.75, 0.75, 1, 1),   // Bottom-right corner.
+            MatrixPortion(0.75, 0, 1, 1),      // Partial rows.
+            MatrixPortion(0.4, 0.5, 0.6, 0.8)  // Somewhere in the middle.
+        )),
+    testing::PrintToStringParamName());
+
+} // namespace kai::test
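
The hybrid variant above differs from the bf16p/bf16p one only in the LHS path: fn_pack_lhs and the other packed-LHS hooks are nullptr, so the kernel consumes the raw row-major FP32 LHS directly and only the RHS (plus bias) is repacked to BF16. A sketch under the same assumptions as the previous one (helper name and surrounding harness are ours):

    // Sketch only: RHS-packed, LHS-raw flow of the hybrid f32 x bf16p method.
    void run_f32_bf16p(
        const MatMulMethod& method, size_t m, size_t n, size_t k,  //
        const void* lhs, size_t lhs_stride,                        //
        const void* rhs, size_t rhs_stride,                        //
        const void* bias, void* dst, size_t dst_stride) {
        assert(!method.is_pack_lhs_needed() && method.is_pack_rhs_needed());

        // Only the RHS and the bias are converted into the packed BF16 24x4 layout.
        std::vector<uint8_t> packed_rhs(method.fn_get_packed_rhs_size(n, k));
        method.pack_rhs(n, k, rhs, rhs_stride, bias, /*scale=*/nullptr, packed_rhs.data());

        // The FP32 LHS buffer is passed to the main kernel unmodified.
        method.main_kernel(
            m, n, k, lhs, packed_rhs.data(), bias, dst, lhs_stride, rhs_stride, dst_stride,
            -std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity());
    }
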
diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp
index 3e68bac08aab92098c71e356cc4bc040c3b38939..2e9e4450fbf8552da84f7b7b6669c270c6877c63 100644
--- a/test/tests/matmul_test.cpp
+++ b/test/tests/matmul_test.cpp
@@ -21,11 +21,13 @@
 #include <vector>
 
 #include "kai/kai_common.h"
+#include "test/common/MatMulMethod.hpp"
 #include "test/common/compare.hpp"
 #include "test/common/cpu_info.hpp"
 #include "test/common/data_format.hpp"
 #include "test/common/data_type.hpp"
 #include "test/common/float16.hpp"
+#include "test/common/matmul_test_common.hpp"
 #include "test/common/matrix_portion.hpp"
 #include "test/common/printer.hpp"
 #include "test/common/sme.hpp"
@@ -51,293 +53,6 @@
 
 namespace kai::test {
 
-// NOLINTBEGIN(misc-non-private-member-variables-in-classes)
-
-/// Matrix multiplication method.
-struct MatMulMethod {
-    std::string_view name;  ///< Name of matmul method.
-
-    size_t m0;  ///< Block size in M dimension.
-    size_t n0;  ///< Block size in N dimension.
-
-    bool lhs_transposed;  ///< LHS matrix is transposed.
-    bool rhs_transposed;  ///< RHS matrix is transposed.
-
-    bool is_sme2;  ///< Test is an SME2 test.
-
-    DataFormat dst_format;         ///< Data format of the destination matrix.
-    DataFormat lhs_format;         ///< Data format of the LHS matrix.
-    DataFormat packed_lhs_format;  ///< Data format of the packed LHS matrix.
-    DataFormat rhs_format;         ///< Data format of the RHS matrix.
-    DataFormat packed_rhs_format;  ///< Data format of the packed RHS matrix.
-    DataFormat bias_format;        ///< Data format of the bias vector.
-
-    /// Gets mr value.
-    ///
-    /// This is the packing parameter which must be used to pack the LHS matrix (if necessary).
-    ///
-    /// @return The mr value.
-    std::function<size_t(void)> fn_get_mr;
-
-    /// Gets nr value.
-    ///
-    /// This is the packing parameter which must be used to pack the RHS matrix (if necessary).
-    ///
-    /// @return The nr value.
-    std::function<size_t(void)> fn_get_nr;
-
-    /// Gets kr value.
-    ///
-    /// This is the packing parameter which must be used to pack the LHS and RHS matrix (if necessary).
-    ///
-    /// @return The kr value.
-    std::function<size_t(void)> fn_get_kr;
-
-    /// Gets sr value.
-    ///
-    /// This is the packing parameter which must be used to pack the RHS matrix.
-    ///
-    /// @return The sr value.
-    std::function<size_t(void)> fn_get_sr;
-
-    /// Gets m step value for main kernel.
-    ///
-    /// The starting row index must be divisible by `m_step`.
-    ///
-    /// @return The m step value.
-    std::function<size_t(void)> fn_get_main_m_step;
-
-    /// Gets n step value for RHS packing kernel.
-    ///
-    /// The starting row index must be divisible by `n_step`.
-    ///
-    /// @return The n step value.
-    std::function<size_t(void)> fn_get_pack_rhs_n_step;
-
-    /// Gets n step value for main kernel.
-    ///
-    /// The starting column index must be divisible by `n_step`.
-    ///
-    /// @return The n step value.
-    std::function<size_t(void)> fn_get_main_n_step;
-
-    /// Gets the offset in bytes of the LHS matrix.
-    ///
-    /// @param[in] m_idx Coordinate of the matrix in M dimension.
-    /// @param[in] stride Row stride in bytes.
-    ///
-    /// @return The offset in bytes.
-    std::function<size_t(size_t m_idx, size_t stride)> fn_get_lhs_offset;
-
-    /// Gets the size in bytes of the packed LHS matrix.
-    ///
-    /// @param[in] m Number of rows in the unpacked LHS matrix.
-    /// @param[in] k Number of columns in the unpacked LHS matrix.
-    /// @param[in] mr Number of rows to be interleaved.
-    /// @param[in] kr Unused. Must be 1.
-    /// @param[in] sr Unused. Must be 1.
-    ///
-    /// @return The size in bytes.
-    std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)> fn_get_packed_lhs_size;
-
-    /// Gets the offset in bytes of the packed LHS matrix.
-    ///
-    /// @param[in] m_idx Coordinate of the matrix in M dimension.
-    /// @param[in] k Size of the matrix in K dimension.
-    ///
-    /// @return The offset in bytes.
-    std::function<size_t(size_t m_idx, size_t k)> fn_get_packed_lhs_offset;
-
-    /// Preprocesses the LHS matrix.
-    ///
-    /// @param[in] m Number of rows of the unpacked LHS matrix.
-    /// @param[in] k Common dimension between the LHS and RHS matrix.
-    /// @param[in] mr Block size in M dimension. It must be {{ kernel.interleave_by }}VL.
-    /// @param[in] kr Block size in K dimension. It must be {{ kernel.block_by }}.
-    /// @param[in] sr Number of kr splits. It must be 1.
-    /// @param[in] m_idx_start Unused. Must be 0.
-    /// @param[in] lhs LHS matrix data buffer.
-    /// @param[in] lhs_stride Row stride in bytes of the LHS matrix.
-    /// @param[out] lhs_packed Packed LHS matrix.
-    std::function<void(
-        size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
-        void* lhs_packed)>
-        fn_pack_lhs;
-
-    /// Gets a value indicating whether LHS packing is needed.
-    [[nodiscard]] bool is_pack_lhs_needed() const {
-        return fn_pack_lhs != nullptr;
-    }
-
-    /// Gets the offset in bytes of the RHS matrix.
-    ///
-    /// @param[in] n_idx Coordinate of the matrix in N dimension.
-    ///
-    /// @return The offset in bytes.
-    std::function<size_t(size_t n_idx)> fn_get_rhs_offset;
-
-    /// Gets the size in bytes of the packed RHS matrix.
-    ///
-    /// @param[in] n Size of the matrix in N dimension.
-    /// @param[in] k Size of the matrix in K dimension.
-    ///
-    /// @return The size in bytes.
-    std::function<size_t(size_t n, size_t k)> fn_get_packed_rhs_size;
-
-    /// Gets the offset in bytes of the packed RHS matrix in the RHS packing kernel.
-    ///
-    /// @param[in] n_idx Coordinate of the matrix in N dimension.
-    /// @param[in] k Size of the matrix in K dimension.
-    ///
-    /// @return The offset in bytes.
-    std::function<size_t(size_t n_idx, size_t k)> fn_get_pack_rhs_packed_rhs_offset;
-
-    /// Gets the offset in bytes of the packed RHS matrix in the main kernel.
-    ///
-    /// @param[in] n_idx Coordinate of the matrix in N dimension.
-    /// @param[in] k Size of the matrix in K dimension.
-    ///
-    /// @return The offset in bytes.
-    std::function<size_t(size_t n_idx, size_t k)> fn_get_main_packed_rhs_offset;
-
-    std::function<void(
-        size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
-        const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
-        fn_pack_rhs;
-
-    /// Gets the offset in bytes to the data element in the bias buffer.
-    ///
-    /// @param[in] n_idx Column index.
-    ///
-    /// @return The offset in bytes to the data element.
-    std::function<size_t(size_t n_idx)> fn_get_bias_offset;
-
-    /// Gets the offset in bytes to the data element in the destination matrix buffer.
-    ///
-    /// @param[in] m_idx Row index.
-    /// @param[in] n_idx Column index.
-    /// @param[in] stride Row stride in bytes.
-    ///
-    /// @return The offset in bytes to the data element.
-    std::function<size_t(size_t m_idx, size_t n_idx, size_t stride)> fn_get_dst_offset;
-
-    /// Gets the size in bytes of the destination matrix buffer.
-    ///
-    /// @param[in] m Number of rows.
-    /// @param[in] n Number of columns.
-    ///
-    /// @return The size in bytes of the destination matrix buffer.
-    std::function<size_t(size_t m, size_t n)> fn_get_dst_size;
-
-    /// Performs F16 or F32 matrix multiplication with RHS packing
-    /// followed by clamp operation.
-    ///
-    /// @param[in] m Size of the matrix in M dimension.
-    /// @param[in] n Size of the matrix in N dimension.
-    /// @param[in] k Size of the matrix in K dimension.
-    /// @param[in] lhs LHS data buffer.
-    /// @param[in] packed_rhs Packed RHS data buffer.
-    /// @param[out] dst Output data buffer.
-    /// @param[in] lhs_stride LHS row stride.
-    /// @param[in] dst_stride Output row stride.
-    /// @param[in] clamp_min Lower bound of the output data.
-    /// @param[in] clamp_max Upper bound of the output data.
-    std::function<void(
-        size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* packed_rhs, void* dst,
-        size_t dst_stride_row, size_t dst_stride_col, Float16 clamp_min, Float16 clamp_max)>
-        fn_matmul_f16_f16_f16p;
-
-    std::function<void(
-        size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* packed_rhs, void* dst,
-        size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max)>
-        fn_matmul_f32_f32_f32p;
-
-    /// Performs F32 matrix multiplication with LHS & RHS packing
-    /// followed by clamp operation.
-    ///
-    /// @param[in] m Number of output rows to be computed.
-    /// @param[in] n Number of output columns to be computed.
-    /// @param[in] k Common dimension of the LHS and RHS operands.
-    /// @param[in] packed_lhs Packed LHS matrix buffer.
-    /// @param[in] packed_rhs Packed RHS matrix buffer.
-    /// @param[out] dst Output matrix buffer.
-    /// @param[in] dst_stride_row Row stride in bytes of the output matrix.
-    /// @param[in] dst_stride_col Column stride in bytes of the output matrix.
-    /// @param[in] clamp_min Minimum value to clamp the final result.
-    /// @param[in] clamp_max Maximum value to clamp the final result.
-    std::function<void(
-        size_t m, size_t n, size_t k, const void* packed_lhs, const void* packed_rhs, void* dst,
-        size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max)>
-        fn_matmul_f32_f32p_f32p;
-
-    /// Gets a value indicating whether pre-processing the RHS matrix is needed.
-    [[nodiscard]] bool is_pack_rhs_needed() const {
-        return fn_pack_rhs != nullptr;
-    }
-
-    /// Preprocesses the RHS matrix.
-    ///
-    /// @param[in] n Size of the matrix in N dimension.
-    /// @param[in] k Size of the matrix in K dimension.
-    /// @param[in] rhs RHS data buffer.
-    /// @param[in] rhs_row_stride RHS row stride.
-    /// @param[in] bias Bias data buffer.
-    /// @param[in] scale Quantization scales data buffer.
-    /// @param[out] packed_rhs Packed RHS data buffer.
-    void pack_rhs(
-        size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale,
-        void* packed_rhs) const {
-        KAI_UNUSED(n);
-        KAI_UNUSED(k);
-        KAI_UNUSED(rhs);
-        KAI_UNUSED(rhs_row_stride);
-        KAI_UNUSED(bias);
-        KAI_UNUSED(scale);
-        KAI_UNUSED(packed_rhs);
-
-        if (fn_pack_rhs != nullptr) {
-            fn_pack_rhs(
-                1, n, k, fn_get_nr(), fn_get_kr(), fn_get_sr(), rhs_row_stride, rhs, bias, nullptr, packed_rhs, 0,
-                nullptr);
-        } else {
-            KAI_ERROR("RHS pre-processing is not supported!");
-        }
-    }
-
-    [[nodiscard]] bool has_main_kernel() const {
-        return fn_matmul_f16_f16_f16p != nullptr || fn_matmul_f32_f32p_f32p != nullptr ||
-            fn_matmul_f32_f32_f32p != nullptr;
-    }
-
-    void main_kernel(
-        size_t m, size_t n, size_t k, const void* lhs, const void* rhs, const void* bias, void* dst, size_t lhs_stride,
-        size_t rhs_stride, size_t dst_stride, float clamp_min, float clamp_max) const {
-        KAI_UNUSED(bias);
-        KAI_UNUSED(rhs_stride);
-
-        if (fn_matmul_f16_f16_f16p) {
-            fn_matmul_f16_f16_f16p(
-                m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(Float16), clamp_min,
-                static_cast<Float16>(clamp_max));
-        } else if (fn_matmul_f32_f32_f32p) {
-            fn_matmul_f32_f32_f32p(
-                m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(float), clamp_min,
-                static_cast<float>(clamp_max));
-        } else if (fn_matmul_f32_f32p_f32p) {
-            fn_matmul_f32_f32p_f32p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max);
-        } else {
-            KAI_ERROR("Main kernel is not available!");
-        }
-    }
-};
-
-// NOLINTEND(misc-non-private-member-variables-in-classes)
-
 /// List of supported matrix multiplication methods.
 static const std::array matmul_methods = {
     MatMulMethod{
@@ -589,35 +304,11 @@ static const std::array matmul_methods = {
     },
 };
 
-/// Matrix multiplication shape.
-struct MatMulShape {
-    size_t m;  ///< LHS height.
-    size_t n;  ///< RHS width.
-    size_t k;  ///< LHS width and RHS height.
-};
-
-/// Matrix multiplication test information.
-using MatMulTestParams = std::tuple<size_t, MatMulShape, MatrixPortion>;
-
-/// Prints the test information.
-void PrintTo(const MatMulTestParams& param, std::ostream* os) {
-    const auto& [method_no, shape, portion] = param;
-
-    // NOLINTBEGIN(cppcoreguidelines-pro-bounds-constant-array-index)
-    *os << "Method_" << matmul_methods[method_no].name                           //
-        << "__M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k           //
-        << "__PortionStartRow_" << static_cast<int>(portion.start_row() * 1000)  //
-        << "__PortionStartCol_" << static_cast<int>(portion.start_col() * 1000)  //
-        << "__PortionHeight_" << static_cast<int>(portion.height() * 1000)       //
-        << "__PortionWidth_" << static_cast<int>(portion.width() * 1000);
-    // NOLINTEND(cppcoreguidelines-pro-bounds-constant-array-index)
-}
-
 /// Matrix multiplication test fixture.
 class MatMulTest : public testing::TestWithParam<MatMulTestParams> {
 private:
     /// Unique ID: m, n, k, method_id.
-    using TestDataId = std::tuple<size_t, size_t, size_t, size_t>;
+    using TestDataId = std::tuple<size_t, size_t, size_t, std::string_view>;
 
 protected:
     /// Cached test data that is shared between multiple test cases.
@@ -633,8 +324,8 @@ protected:
 
     /// Gets the test data for the current test case.
     static const TestData& test_data() {
-        const auto& [method_no, info, portion] = GetParam();
-        const TestDataId data_id{info.m, info.n, info.k, method_no};
+        const auto& [method, info, portion] = GetParam();
+        const TestDataId data_id{info.m, info.n, info.k, method.name};
 
         // If the test data is already available, returns it.
         const auto data_it = _data.find(data_id);
@@ -644,8 +335,6 @@ protected:
         }
 
         // Generates the test data.
-        const auto& method = matmul_methods.at(method_no);
-
         const auto has_lhs_pack = method.packed_lhs_format.data_type() != DataType::UNKNOWN;
         const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN;
         const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN;
@@ -720,9 +409,8 @@ std::map<MatMulTest::TestDataId, MatMulTest::TestData> MatMulTest::_data;
 
 /// Tests the LHS packing kernel.
 TEST_P(MatMulTest, PackedLhs) {
-    const auto& [method_no, info, portion] = GetParam();
+    const auto& [method, info, portion] = GetParam();
     const auto& data = test_data();
-    const auto& method = matmul_methods.at(method_no);
 
     if (method.is_sme2 && !cpu_has_sme2()) {
         GTEST_SKIP();
@@ -771,9 +459,8 @@ TEST_P(MatMulTest, PackedLhs) {
 
 /// Tests the RHS packing kernel.
 TEST_P(MatMulTest, PackedRhs) {
-    const auto& [method_no, info, portion] = GetParam();
+    const auto& [method, info, portion] = GetParam();
     const auto& data = test_data();
-    const auto& method = matmul_methods.at(method_no);
 
     if (method.is_sme2 && !cpu_has_sme2()) {
         GTEST_SKIP();
@@ -842,9 +529,8 @@ TEST_P(MatMulTest, PackedRhs) {
 
 /// Tests the output.
 TEST_P(MatMulTest, Output) {
-    const auto& [method_no, info, portion] = GetParam();
+    const auto& [method, info, portion] = GetParam();
    const auto& data = test_data();
-    const auto& method = matmul_methods.at(method_no);
 
     if (method.is_sme2 && !cpu_has_sme2()) {
         GTEST_SKIP();
@@ -940,7 +626,7 @@ TEST_P(MatMulTest, Output) {
 INSTANTIATE_TEST_SUITE_P(
     MatMul, MatMulTest,
     testing::Combine(
-        testing::Range<size_t>(0, matmul_methods.size()),
+        testing::ValuesIn(matmul_methods),
         testing::Values(
             MatMulShape{6, 16, 32},   //
             MatMulShape{12, 32, 17},  //