diff --git a/BUILD.bazel b/BUILD.bazel index c4abdc25d53561cb9f090a6554aea430ab0d0536..7dbe516642231cffa60d83f74154302cb4fb92cc 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates # # SPDX-License-Identifier: Apache-2.0 # @@ -36,6 +36,7 @@ kai_c_library( name = "kleidiai", visibility = ["//visibility:public"], deps = [ + "//kai/ukernels/dw_conv", "//kai/ukernels/matmul", ], ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6fe93eff1b7fd3558bfac5c9252950dd0a46999d..d4adf59c41fe09e7a0e2392a7167b0ffcd576ad3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -310,6 +310,7 @@ set(KLEIDIAI_FILES_SME2_ASM kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_asm.S kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S + kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla_asm.S ) set(KLEIDIAI_FILES_SME2 @@ -319,6 +320,8 @@ set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c + kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c + kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c ) @@ -432,6 
+435,7 @@ if(KLEIDIAI_BUILD_TESTS) test/reference/reduce.cpp test/reference/reorder.cpp test/reference/transpose.cpp + test/reference/depthwise_conv.cpp ) target_compile_options(kleidiai_test_framework @@ -467,6 +471,7 @@ if(KLEIDIAI_BUILD_TESTS) test/tests/buffer_test.cpp test/tests/float16_test.cpp test/tests/imatmul_test.cpp + test/tests/depthwise_planar_test.cpp test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp diff --git a/examples/dw_conv_f32_f32_f32p_planar_sme2/CMakeLists.txt b/examples/dw_conv_f32_f32_f32p_planar_sme2/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..b58eff3f802ce32098691c243de6d1ba3f79f345 --- /dev/null +++ b/examples/dw_conv_f32_f32_f32p_planar_sme2/CMakeLists.txt @@ -0,0 +1,32 @@ +# +# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +cmake_minimum_required(VERSION 3.16) +enable_language(ASM) + +project(dw_conv_f32_f32_f32p_planar_sme2) + +set(CMAKE_CXX_STANDARD 17) +set(KAI_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../) +set(KAI_BUILD ${KAI_PATH}/build) + +include_directories(${KAI_PATH}) + +# Files required to build the executable +add_executable( + dw_conv_f32_f32_f32p_planar_sme2 dconv.cpp + "${KAI_PATH}/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla_asm.S" + "${KAI_PATH}/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c" + "${KAI_PATH}/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c" + ) + +target_compile_options(dw_conv_f32_f32_f32p_planar_sme2 + PRIVATE "-march=armv8.2-a+sve+sve2;-fno-tree-vectorize" +) + +target_compile_definitions(dw_conv_f32_f32_f32p_planar_sme2 + PRIVATE $<$<CONFIG:Debug>:KAI_DEBUG> +) diff --git a/examples/dw_conv_f32_f32_f32p_planar_sme2/dconv.cpp
b/examples/dw_conv_f32_f32_f32p_planar_sme2/dconv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2203b7947093aa64549322f353e932bc3b63e865 --- /dev/null +++ b/examples/dw_conv_f32_f32_f32p_planar_sme2/dconv.cpp @@ -0,0 +1,335 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "kai/kai_common.h" +#include "kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h" +#include "kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.h" + +using VEC_TYPE = std::vector; + +namespace { +constexpr float clamp_min = std::numeric_limits::lowest(); +constexpr float clamp_max = std::numeric_limits::max(); + +struct Padding2D { + size_t left = 0; + size_t right = 0; + size_t bottom = 0; + size_t top = 0; +}; + +struct Shape { + size_t n = 1; + size_t h = 1; + size_t w = 1; + size_t c = 1; + + [[nodiscard]] auto size() const -> size_t { + return n * h * w * c; + } + + friend std::ostream& operator<<(std::ostream& os, const Shape& shape) { + os << " [ " << shape.n << " , " << shape.h << " ," << shape.w << " , " << shape.c << " ] "; + return os; + } + + constexpr const std::size_t& operator[](std::size_t idx) const { + switch (idx) { + case 0: + return n; + case 1: + return h; + case 2: + return w; + case 3: + return c; + default: + throw std::out_of_range("Shape-index out of range (0-3)"); + } + } +}; + +#ifdef KAI_DEBUG +void print_tensor(const Shape& shape, const char* name, const float* src) { + std::cout << "\n\n" << name << " = [\n"; + for (size_t n = 0; n < shape.n; n++) { + std::cout << "\n"; + for (size_t y = 0; y < shape.h; ++y) { + std::cout << " ["; + for (size_t x = 0; x < shape.w; x++) { + std::cout << "["; + for (size_t c = 0; c < shape.c; c++) { + if (c != 0) std::cout << " , "; + std::cout 
<< std::setprecision(3) << std::fixed + << src[n * shape.h * shape.w * shape.c + y * shape.w * shape.c + x * shape.c + c]; + } + std::cout << "] "; + } + std::cout << ("],\n"); + } + } + std::cout << ("]\n\n"); +} + +void print_raw(const Shape& shape, const char* name, const VEC_TYPE& src) { + std::cout << "\n\n" << name << " = ["; + for (size_t i = 0; i < shape.size(); i++) { + if (i != 0) std::cout << " , "; + std::cout << std::setprecision(1) << std::fixed << (float)src[i]; + } + std::cout << "]\n"; +} + +#endif +/// Fills the matrix with incremental values according to the provided weight. +/// @param[in] size Total number of elements to fill in the passed vector. +/// @param[out] dst Vector representing a tensor to fill. +/// @param[in] weight Scale factor: element i is set to (10 * i) * weight. +void fill_matrix(size_t size, VEC_TYPE& dst, const float weight) { + for (size_t i = 0; i < size; i++) { + dst[i] = float((10 * i) * weight); + } +} + +void fill_matrix_uniform(size_t size, VEC_TYPE& dst, const float weight) { + for (size_t i = 0; i < size; i++) { + dst[i] = float(weight); + } +} + +/// Depthwise Convolution - Expects NHWC dataformat. Padding value is 0. +/// +/// @tparam T Data type. +/// +/// @param[in] batches Batch dimension of feature map. +/// @param[in] in_height height of feature map. +/// @param[in] in_width width of feature map. +/// @param[in] channels Number of channels in feature map. +/// @param[in] filter_height Height dimension in filter. +/// @param[in] filter_width Width of convolution filter. +/// @param[in] feature_map Ptr to start of feature map. +/// @param[in] weights Ptr to start of weights buffer/tensor. +/// @param[in] bias Ptr to start of bias buffer. +/// @param[in] clamp_min float value to clamp output to (lower bound). +/// @param[in] clamp_max float value to clamp output to (upper bound). +/// +/// Nothing is returned: results are written to the `out` buffer; `pad` gives the zero padding applied on each edge.
+template +void depthwise_reference( + const size_t batches, const size_t in_height, const size_t in_width, const size_t channels, + const size_t filter_height, const size_t filter_width, const void* feature_map, const void* weights, + const void* bias, void* out, float clamp_min, float clamp_max, const Padding2D pad) { + // Calculate output dims (Padding = Valid). + const size_t out_height = (in_height + pad.top + pad.bottom + 1 - filter_height); + const size_t out_width = in_width + pad.left + pad.right + 1 - filter_width; + const size_t out_size = out_height * out_width * batches * channels; + + // We accumulate in FP32 and clamp and cast to return type later. + std::vector acc(out_size, 0.0f); + + for (size_t b = 0; b < batches; ++b) { + for (size_t out_h = 0; out_h < out_height; ++out_h) { + for (size_t out_w = 0; out_w < out_width; ++out_w) { + const size_t out_base = ((b * out_height + out_h) * out_width + out_w) * channels; + + // Apply filter to feature map. + for (size_t ic = 0; ic < channels; ++ic) { + float sum = 0.0f; + + for (size_t kernel_h = 0; kernel_h < filter_height; ++kernel_h) { + // Determine if input height bounds. If not, then this is padding. + const int in_y = static_cast(out_h + kernel_h) - static_cast(pad.top); + if (in_y < 0 || in_height <= static_cast(in_y)) continue; + + for (size_t kernel_w = 0; kernel_w < filter_width; ++kernel_w) { + // Determine if in input width bounds, if not this is padding. 
+ const int in_x = static_cast(out_w + kernel_w) - static_cast(pad.left); + if (in_x < 0 || in_width <= static_cast(in_x)) continue; + + auto in_idx = ((b * in_height + in_y) * in_width + in_x) * channels + ic; + auto weights_idx = ((kernel_h * filter_width) + kernel_w) * channels + ic; + + auto wei_value = reinterpret_cast(weights)[weights_idx]; + auto in_value = reinterpret_cast(feature_map)[in_idx]; + + // Perform actual accumulation and store in output vector + sum += in_value * wei_value; + } + } + + auto out_idx = out_base + ic; + float bias_value = reinterpret_cast(bias)[ic]; + sum = sum + bias_value; + sum = std::clamp(sum, clamp_min, clamp_max); + reinterpret_cast(out)[out_idx] = sum; + } + } + } + } +} +} // namespace + +int main() { + const int batches = 1; + enum class pad_mode { SAME, VALID }; + + size_t total_test = 0; + for (pad_mode pad : {pad_mode::SAME, pad_mode::VALID}) { + for (size_t width = 128; width < 129; width += 2) { + for (size_t height = 141; height < 142; height += 2) { + for (size_t channels = 1; channels < 64; channels += 7) { + total_test++; + const int filter_height = 3; + const int filter_width = 3; + const int depth_multiplier = 1; // Only dm =1 supported. + + assert(filter_height > 1 && filter_width > 1); + + const size_t pad_total_height = (pad == pad_mode::SAME) ? filter_height - 1 : 0; + const size_t pad_total_width = (pad == pad_mode::SAME) ? 
filter_width - 1 : 0; + Padding2D padding; + padding.top = pad_total_height / 2; + padding.left = pad_total_width / 2; + padding.right = pad_total_width - padding.left; + padding.bottom = pad_total_height - padding.top; + + Shape in_shape{batches, height, width, channels}; + Shape wei_shape{filter_height, filter_width, channels, depth_multiplier}; + Shape bias_shape{depth_multiplier * channels}; + Shape out_shape{ + batches, (height + padding.top + padding.bottom + 1 - filter_height), + (width + padding.left + padding.right + 1 - filter_width), channels * depth_multiplier}; + + VEC_TYPE input(in_shape.size(), 0.0f); + VEC_TYPE weights(wei_shape.size(), 0.1f); + VEC_TYPE bias(bias_shape.size(), 0.0f); + VEC_TYPE out(out_shape.size(), 0.0f); + VEC_TYPE ref(out_shape.size(), 0.0f); + + fill_matrix(in_shape.size(), input, 0.01f); + fill_matrix(wei_shape.size(), weights, 0.02f); + fill_matrix_uniform(bias_shape.size(), bias, 1.f); + + // For testing using Python. +#ifdef KAI_DEBUG + { + std::cout << "\n#BEGIN PARAMS\n"; + std::cout << "\nbatch, height, width, channels = " << batches << ", " << height << ", " << width + << ", " << channels << std::endl; + std::cout << "\nfilter_height, filter_width = " << filter_height << ", " << filter_width + << std::endl; + print_raw(in_shape, "Inputs ", input); + print_raw(wei_shape, "Weights ", weights); + print_raw(bias_shape, "Bias ", bias); + std::cout << "\npad_top, pad_bottom = " << padding.top << ", " << padding.bottom << std::endl; + std::cout << "\npad_left, pad_right = " << padding.left << ", " << padding.right << std::endl + << std::endl; + std::cout << "\n#END PARAMS\n"; + } +#endif // KAI_DEBUG + + // ------------------------------------------------- + // 1. Calculate Reference Depthwise Values. 
+ // ------------------------------------------------- + depthwise_reference( + batches, height, width, channels, filter_height, filter_width, (const void*)input.data(), + (const void*)weights.data(), (const void*)bias.data(), (void*)ref.data(), clamp_min, clamp_max, + padding); + + // ------------------------------------------------- + // 2. Pack weights for use in SME Kernel + // ------------------------------------------------- + // const size_t vec_length = kai_get_sme_vector_length_u32(); + const size_t packed_size = + kai_rhs_get_dst_size_dw_conv_pack_x32p1vl_x32_sme(filter_height, filter_width, channels) / + sizeof(float); + + // Run packing kernel. + std::vector weights_packed(packed_size); + kai_run_rhs_dw_conv_pack_x32p1vl_x32_sme( + weights.data(), weights_packed.data(), filter_height, filter_width, wei_shape[0], wei_shape[1], + channels); + +#ifdef KAI_DEBUG + const size_t vec_length = kai_get_sme_vector_length_u32(); + // Print packed weights - 1VL per row. + print_tensor( + {1, (weights_packed.size() / vec_length), 1, vec_length}, + "\n Weights Packed : ", weights_packed.data()); +#endif + // ------------------------------------------------- + // 3. Kernel takes in 6 rows of input and generates + // rows of output across all channels at a time. + // ------------------------------------------------- + constexpr size_t rows_handled = 4; // no of rows kernel handles each time. + for (size_t out_row = 0; out_row < out_shape.h; out_row += rows_handled) { + // Variables below used to calculate start of input pointer. + const int start_in_row = out_row - padding.top; + const size_t pad_top = (start_in_row < 0) ? (-start_in_row) : 0; + const size_t in_row = (start_in_row < 0) ? 0 : start_in_row; + + // Calculate row strides for pointer. + const size_t in_row_stride_elements = (width * channels); + const size_t out_row_stride_elements = (out_shape.w * out_shape.c); + + // Number of input rows that can be read, number of output rows to calculate. 
+ const size_t valid_input_rows = (in_row < height) ? (height - in_row) : 0; + const size_t valid_out_rows = (out_shape.h - out_row); + + // Increment output/input pointers according to tile being calculated. + const auto inptr = input.data() + (in_row * in_row_stride_elements); + auto outptr = out.data() + (out_row * out_row_stride_elements); + + // NOTE: Kernel expects strides to be passed as bytes. + // f32_f32_f32p1vl -> f32 output, f32 LHS, packed F32 rhs as 1VL blocks. + // 3x3_s1 : 3x3 filter with stride 1 + // 4xc : up to 4 output rows are produced per call, across all channels (c). + kai_run_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( + inptr, in_row_stride_elements * sizeof(float), channels * sizeof(float), pad_top, + padding.left, valid_input_rows, weights_packed.data(), bias.data(), outptr, + out_shape.c * sizeof(float), out_row_stride_elements * sizeof(float), valid_out_rows, + clamp_min, clamp_max, 0.0f); + } + +#ifdef KAI_DEBUG + // Print outputs + print_tensor(out_shape, "Reference : ", reinterpret_cast(ref.data())); + print_tensor(out_shape, "\n\n Actual : ", out.data()); + std::cout << "\n\nOut shape : " << out_shape << std::endl; +#endif // KAI_DEBUG + + /// Check for mismatches in the tests.
+ size_t mismatches = 0; + for (size_t i = 0; i < out_shape.size(); i++) { + float ref_value = ref[i]; + // FP32 rel tolerance - allows deviations of up to 0.05% + const auto err = (std::abs(out[i] - ref_value) / std::abs(ref_value)); + if (err > 0.0005) { + std::cout << "Mismatches(Expected:Actual)" << ref_value << " : " << out[i] << std::endl; + mismatches++; + } + } + if (mismatches > 0) { // Report the total once per test case, after every element has been compared. + std::cout << "\nNumber of mismatches: " << mismatches << std::endl; + } + } + } + } + } + std::cout << "total tests run: " << total_test << std::endl; +} diff --git a/kai/ukernels/dw_conv/BUILD.bazel b/kai/ukernels/dw_conv/BUILD.bazel new file mode 100644 index 0000000000000000000000000000000000000000..baeb21f01a3bfee95fd32cd8bfe1cb79bb0d954c --- /dev/null +++ b/kai/ukernels/dw_conv/BUILD.bazel @@ -0,0 +1,52 @@ +# +# SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +load( + "//:kai_defs.bzl", + "kai_c_library", + "kai_cpu_sme2", +) + +package(default_visibility = ["//visibility:private"]) + +# buildifier: keep sorted +SME2_KERNELS = [ + "pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme", +] + +SME2_KERNELS_ASM = [ + "depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla", +] + +kai_c_library( + name = "sme2_impl", + srcs = [ukernel + ".c" for ukernel in SME2_KERNELS], + cpu_uarch = kai_cpu_sme2(), + textual_hdrs = [ukernel + ".h" for ukernel in SME2_KERNELS], +) + +kai_c_library( + name = "sme2_impl_asm", + srcs = [ukernel + "_asm.S" for ukernel in SME2_KERNELS_ASM] + [ukernel + ".c" for ukernel in SME2_KERNELS_ASM], + cpu_uarch = kai_cpu_sme2(), + textual_hdrs = [ukernel + ".h" for ukernel in SME2_KERNELS_ASM], +) + +kai_c_library( + name = "interface", + textual_hdrs = glob(["**/*_interface.h"]), + visibility = ["//visibility:public"], +) + +kai_c_library( + name = "dw_conv", + visibility = ["//visibility:public"], + deps = [ + ":interface", + ":sme2_impl", +
":sme2_impl_asm", + ], +) diff --git a/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c new file mode 100644 index 0000000000000000000000000000000000000000..9abbc1dfd1ceb11f2a6e5c56dc1f104388836b24 --- /dev/null +++ b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c @@ -0,0 +1,133 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include "kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h" + +#include +#include + +#include "kai/kai_common.h" + +// Number of rows iterated through each call. 
+static const size_t kai_mr = 4; +// Filter/Kernel height and width +static const size_t kai_kh = 3; +static const size_t kai_kw = 3; +static const size_t kai_kr = 1; + +typedef struct { + const void* inptr; + size_t pad_top; + size_t pad_bottom; + const void* weights; + size_t input_cols; + size_t output_cols; + void** outptrs; + const void* ld_out_cols; + size_t ld_in_vl; + const void* ld_out_vls; + float clamp_min; + float clamp_max; + size_t pad_left; + const void* bias; + size_t current_channel; + size_t n_channels; +} KernelArgs; + +void kai_kernel_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( + const KernelArgs* args, size_t ld_in_row, size_t ld_in_col); + +size_t kai_get_m_step_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(void) { + return kai_mr; +} + +size_t kai_get_filter_height_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(void) { + return kai_kh; +} + +size_t kai_get_filter_width_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(void) { + return kai_kw; +} + +size_t kai_get_kr_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(void) { + return kai_kr; +} + +size_t kai_get_dst_size_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( + const size_t out_height, const size_t out_width, const size_t num_channels) { + return out_height * out_width * num_channels * sizeof(float); +} + +size_t kai_get_dst_offset_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( + size_t out_row_idx, size_t stride_out_row) { + KAI_ASSUME(out_row_idx % kai_mr == 0); + return (out_row_idx * stride_out_row); +} + +void kai_run_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( + const void* inptr, size_t stride_in_row, size_t stride_in_col, unsigned int pad_top, unsigned int pad_left, + unsigned int valid_input_rows, const void* weights, const void* bias, void* outptr_start, size_t stride_out_col, + size_t stride_out_row, unsigned int valid_out_rows, float act_min, float act_max, const float pad_value) { + KAI_ASSUME(inptr != NULL); + KAI_ASSUME(weights != 
NULL); + KAI_ASSUME(outptr_start != NULL); + KAI_ASSUME(valid_out_rows != 0); + KAI_ASSUME(pad_value == 0.0F); + + // Zero-initialised scratch row: output rows at or beyond valid_out_rows are pointed here instead of the caller's buffer. + float pad_row[KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(float)] = {0}; + + // Bottom padding: whatever remains of the 6 input rows (4 output rows + 3-row filter - 1) after pad_top and valid_input_rows. + long unsigned int pad_bottom = (6U < (pad_top + valid_input_rows)) ? (0U) : (6U - (pad_top + valid_input_rows)); + + // Leading dims calculated using input parameters. + size_t ld_in_vl = kai_get_sme_vector_length_u32(); + size_t ld_in_row = stride_in_row / sizeof(float); + size_t ld_in_col = stride_in_col / sizeof(float); + + // Calculate matrix dimensions using the strides provided (channel count is taken from the output column stride). + const size_t num_channels = stride_out_col / sizeof(float); + const size_t output_cols = stride_out_row / (sizeof(float) * num_channels); + const size_t valid_input_cols = stride_in_row / (sizeof(float) * num_channels); + + // These arrays are initialised as if they were invalid/padded rows, then set if out row is valid + void* outptrs[4] = {pad_row, pad_row, pad_row, pad_row}; + size_t outlds[4] = {0}; + size_t outvllds[4] = {0}; + + for (unsigned int i = 0; i < 4; i++) { + if (i < valid_out_rows) { + outptrs[i] = (uint8_t*)outptr_start + (i * stride_out_row); + outlds[i] = num_channels; + outvllds[i] = ld_in_vl; + } + } + + KernelArgs args; + args.inptr = inptr; + args.ld_in_vl = ld_in_vl; + args.pad_top = pad_top; + args.pad_bottom = pad_bottom; + args.pad_left = pad_left; + args.weights = weights; + args.bias = bias; + args.input_cols = valid_input_cols; + args.output_cols = output_cols; + args.outptrs = outptrs; + args.ld_out_cols = outlds; + args.ld_out_vls = outvllds; + args.current_channel = 0; + args.n_channels = num_channels; + args.clamp_min = act_min; + args.clamp_max = act_max; + + kai_kernel_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(&args, ld_in_row, ld_in_col); +} +#endif // Architectural features check.
diff --git a/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h new file mode 100644 index 0000000000000000000000000000000000000000..920cd2f6d3a9d394b93116b3f84dc66d336b642c --- /dev/null +++ b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h @@ -0,0 +1,62 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#include + +/// @return Maximum number of rows of output data produced by this kernel. +size_t kai_get_m_step_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(void); + +/// @return Height of the filter used by this kernel. +size_t kai_get_filter_height_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(void); + +/// @return Width of the filter used by this kernel. +size_t kai_get_filter_width_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(void); + +/// This is the packing parameter which must be used to pack the LHS and RHS matrix. +/// @return The kr value. +size_t kai_get_kr_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(void); + +/// @return output size in bytes. +size_t kai_get_dst_size_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( + const size_t out_height, const size_t out_width, const size_t num_channels); + +/// @param[in] out_row_idx Row index in the output matrix; must be a multiple of the m_step value. +/// @param[in] stride_out_row Output row stride in bytes +/// @return Offset in bytes to the start of that row in the output/destination matrix.
+size_t kai_get_dst_offset_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(size_t out_row_idx, size_t stride_out_row); + +/// Runs a depthwise convolution operation followed by a clamp operation. +/// +/// @param[in] inptr Pointer to the start of the first valid input row to be processed. +/// @param[in] stride_in_row Row stride of input tensor in bytes. +/// Same as input_w * input_channel * sizeof(float) when row_dilation = 1 +/// @param[in] stride_in_col Column stride within the input tensor, in bytes. +/// @param[in] pad_top Number of zero pad rows that precede the first valid input row. +/// @param[in] pad_left Number of zero pad columns on the left edge. +/// @param[in] valid_input_rows Count of real input rows available from the start row (identifies bottom padding). +/// @param[in] weights Pointer to packed weights. +/// @param[in] bias Optional pointer to bias array (one float per channel). +/// @param[out] outptr_start Pointer to the first element of the top output row for this tile (up to four rows written). +/// @param[in] stride_out_col Output column stride in bytes; the channel count is derived as stride_out_col / sizeof(float). +/// @param[in] stride_out_row Output row stride in bytes. +/// @param[in] valid_out_rows Number of rows to output to (1-4). +/// @param[in] act_min Lower clamp bound applied to every output value. +/// @param[in] act_max Upper clamp bound applied to every output value. +/// @param[in] pad_value Fill value for padding. This kernel only supports 0.
+void kai_run_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( + const void* inptr, size_t stride_in_row, size_t stride_in_col, unsigned int pad_top, unsigned int pad_left, + unsigned int valid_input_rows, const void* weights, const void* bias, void* outptr_start, size_t stride_out_col, + size_t stride_out_row, unsigned int valid_out_rows, float act_min, float act_max, const float pad_value); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla_asm.S b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..a6276c85537cb61b0fd645e1810497c6a4841045 --- /dev/null +++ b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla_asm.S @@ -0,0 +1,438 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) 
name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla) +KAI_ASM_FUNCTION_LABEL(kai_kernel_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + mov x20, #0x6 + KAI_ASM_INST(0xd503477f) // SMSTART ZA + ldr x7, [x0, #0x10] + ptrue p2.b + KAI_ASM_INST(0x25207812) // ptrue pn10.b + ldr x17, [x0, #0x8] + ld1rw { z3.s }, p2/Z, [x0, #80] + ldr x16, [x0, #0x70] + ld1rw { z9.s }, p2/Z, [x0, #84] + sub x20, x20, x7 + ldr x15, [x0, #0x68] + whilelt p1.s, XZR, x16 + whilelt p9.b, XZR, x20 + whilelt p8.b, XZR, x17 + eor p8.b, p2/Z, p8.b, p9.b +KAI_ASM_LABEL(label_1) // Channel loop + ldr x20, [x0, #0x60] + fmov z16.s, #0x0 + cbz x20, label_2 + ld1w { z16.s }, p1/Z, [x20, x15, LSL #2] +KAI_ASM_LABEL(label_2) // Load bias: Done + ldr x14, [x0, #0x20] + mov x23, #0x6 + add x20, x17, x7 + mov z17.d, z16.d + ldr x22, [x0, #0x18] + lsl x21, x1, #0x2 + mov z18.d, z16.d + mov z19.d, z16.d + ldr x13, [x0, #0x0] + mov x8, #0x0 + sub x23, x23, x20 + sub x20, x14, #0x1 + ldr x11, [x0, #0x28] + KAI_ASM_INST(0xa0404ace) // ld1w { z14.s-z15.s }, pn10.b/Z, [x22] + orr x20, x20, x2, LSL #18 + ld1w { z11.s }, p2/Z, [x22, #2, MUL VL] + addvl x22, x22, #3 + orr x20, x16, x20, LSL #20 + KAI_ASM_INST(0xa0404acc) // ld1w { z12.s-z13.s }, pn10.b/Z, [x22] + lsl x20, x20, #0x2 + madd x21, x21, x17, x13 + ld1w { z0.s }, p2/Z, [x22, #2, MUL VL] + addvl x22, x22, #3 + KAI_ASM_INST(0xa0404ac4) // ld1w { z4.s-z5.s }, pn10.b/Z, [x22] + ld1w { z7.s }, p2/Z, [x22, #2, MUL VL] +KAI_ASM_LABEL(label_3) // Issue prefetches + 
subs x23, x23, #0x1 + KAI_ASM_INST(0xf8b44abc) // rprfm pldstrm, x20, [x21] + add x21, x21, x2, LSL #2 + bgt label_3 + ldr x22, [x0, #0x30] + lsl x21, x1, #0x2 + KAI_ASM_INST(0xc0040e00) // mova za.d[x8, #0], { z16.d-z19.d } + mov x10, #0x2 + ldr x20, [x0, #0x38] + msub x13, x17, x21, x13 + KAI_ASM_INST(0xc0040e01) // mova za.d[x8, #1], { z16.d-z19.d } + ldr x21, [x0, #0x58] + KAI_ASM_INST(0xc0040e02) // mova za.d[x8, #2], { z16.d-z19.d } + ldp x9, x28, [x22], #0x10 + ldp x27, x26, [x20], #0x10 + ldp x25, x24, [x22], #0x10 + ldp x23, x22, [x20], #0x10 + cbz x21, label_5 + cmp x21, x10 + csel x20, x21, x10, LT + sub x21, x21, x20 + sub x10, x10, x20 + cbz x21, label_5 + KAI_ASM_INST(0xc0060c1c) // mova { z28.d-z31.d }, za.d[x8, #0] + sub x11, x11, x21 + KAI_ASM_INST(0xc1a9c87c) // fclamp { z28.s-z31.s }, z3.s, z9.s +KAI_ASM_LABEL(label_4) // Left padding + subs x21, x21, #0x1 + st1w { z28.s }, p1, [x9] + add x9, x9, x27, LSL #2 + st1w { z29.s }, p1, [x28] + add x28, x28, x26, LSL #2 + st1w { z30.s }, p1, [x25] + add x25, x25, x23, LSL #2 + st1w { z31.s }, p1, [x24] + add x24, x24, x22, LSL #2 + bgt label_4 +KAI_ASM_LABEL(label_5) // Left padding: End + adds XZR, x17, x7 + bne label_10 + cbz x10, label_8 + cmp x10, #0x1 + sub x14, x14, x10 + beq label_7 + add x20, x13, x1, LSL #2 + ld1w { z22.s }, p1/Z, [x13] + add x13, x13, x2, LSL #2 + ld1w { z23.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + ld1w { z24.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + ld1w { z25.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + ld1w { z26.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + ld1w { z27.s }, p1/Z, [x20] + KAI_ASM_INST(0xc13e1ac0) // fmla za.s[x8, 0], { z22.s-z25.s }, z14.s + KAI_ASM_INST(0xc13c1ae0) // fmla za.s[x8, 0], { z23.s-z26.s }, z12.s + KAI_ASM_INST(0xc1341b00) // fmla za.s[x8, 0], { z24.s-z27.s }, z4.s +KAI_ASM_LABEL(label_7) // Unpadded: 1 priming loads + add x20, x13, x1, LSL #2 + ld1w { z24.s }, p1/Z, [x13] + add x13, x13, x2, LSL #2 + ld1w { z25.s }, p1/Z, [x20] + add 
x20, x20, x1, LSL #2 + ld1w { z26.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + ld1w { z27.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + ld1w { z28.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + ld1w { z29.s }, p1/Z, [x20] + KAI_ASM_INST(0xc13f1b00) // fmla za.s[x8, 0], { z24.s-z27.s }, z15.s + KAI_ASM_INST(0xc13e1b01) // fmla za.s[x8, 1], { z24.s-z27.s }, z14.s + KAI_ASM_INST(0xc13d1b20) // fmla za.s[x8, 0], { z25.s-z28.s }, z13.s + KAI_ASM_INST(0xc13c1b21) // fmla za.s[x8, 1], { z25.s-z28.s }, z12.s + KAI_ASM_INST(0xc1351b40) // fmla za.s[x8, 0], { z26.s-z29.s }, z5.s + KAI_ASM_INST(0xc1341b41) // fmla za.s[x8, 1], { z26.s-z29.s }, z4.s +KAI_ASM_LABEL(label_8) // Unpadded: 0 priming loads + cbz x14, label_16 + add x20, x13, x1, LSL #2 + ld1w { z20.s }, p1/Z, [x13] + sub x14, x14, #0x1 + ld1w { z21.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + sub x11, x11, #0x1 + ld1w { z22.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + cmp x14, x11 + ld1w { z23.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + csel x21, x14, x11, LT + ld1w { z24.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + add x13, x13, x2, LSL #2 + ld1w { z25.s }, p1/Z, [x20] + sub x11, x11, x21 + cbz x21, label_15 +KAI_ASM_LABEL(label_9) // Unpadded: Main loop + KAI_ASM_INST(0xc13b1a80) // fmla za.s[x8, 0], { z20.s-z23.s }, z11.s + add x20, x13, x1, LSL #2 + subs x21, x21, #0x1 + KAI_ASM_INST(0xc13f1a81) // fmla za.s[x8, 1], { z20.s-z23.s }, z15.s + KAI_ASM_INST(0xc13e1a82) // fmla za.s[x8, 2], { z20.s-z23.s }, z14.s + ld1w { z20.s }, p1/Z, [x13] + add x13, x13, x2, LSL #2 + KAI_ASM_INST(0xc1301aa0) // fmla za.s[x8, 0], { z21.s-z24.s }, z0.s + KAI_ASM_INST(0xc13d1aa1) // fmla za.s[x8, 1], { z21.s-z24.s }, z13.s + KAI_ASM_INST(0xc13c1aa2) // fmla za.s[x8, 2], { z21.s-z24.s }, z12.s + ld1w { z21.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + KAI_ASM_INST(0xc1371ac0) // fmla za.s[x8, 0], { z22.s-z25.s }, z7.s + KAI_ASM_INST(0xc1351ac1) // fmla za.s[x8, 1], { z22.s-z25.s }, z5.s + KAI_ASM_INST(0xc1341ac2) // fmla 
za.s[x8, 2], { z22.s-z25.s }, z4.s + ld1w { z22.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + ld1w { z23.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + KAI_ASM_INST(0xc0060c1c) // mova { z28.d-z31.d }, za.d[x8, #0] + add x8, x8, #0x1 + ld1w { z24.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + KAI_ASM_INST(0xc0040e02) // mova za.d[x8, #2], { z16.d-z19.d } + ld1w { z25.s }, p1/Z, [x20] + KAI_ASM_INST(0xc1a9c87c) // fclamp { z28.s-z31.s }, z3.s, z9.s + st1w { z28.s }, p1, [x9] + add x9, x9, x27, LSL #2 + st1w { z29.s }, p1, [x28] + add x28, x28, x26, LSL #2 + st1w { z30.s }, p1, [x25] + add x25, x25, x23, LSL #2 + st1w { z31.s }, p1, [x24] + add x24, x24, x22, LSL #2 + bgt label_9 + b label_15 +KAI_ASM_LABEL(label_10) // Padded + cbz x10, label_13 + cmp x10, #0x1 + sub x14, x14, x10 + beq label_12 + mov x12, #0x0 + add x20, x13, x1, LSL #2 + KAI_ASM_INST(0x25244500) // psel p0.b, p1.b/Z, p8.b[w12] + ld1w { z23.s }, p0/Z, [x13] + KAI_ASM_INST(0x252c4500) // psel p0.b, p1.b/Z, p8.b[w12, #1] + add x13, x13, x2, LSL #2 + ld1w { z24.s }, p0/Z, [x20] + KAI_ASM_INST(0x25344500) // psel p0.b, p1.b/Z, p8.b[w12, #2] + add x20, x20, x1, LSL #2 + ld1w { z25.s }, p0/Z, [x20] + KAI_ASM_INST(0x253c4500) // psel p0.b, p1.b/Z, p8.b[w12, #3] + add x20, x20, x1, LSL #2 + mov x12, #0x4 + ld1w { z26.s }, p0/Z, [x20] + KAI_ASM_INST(0x25244500) // psel p0.b, p1.b/Z, p8.b[w12] + add x20, x20, x1, LSL #2 + KAI_ASM_INST(0xc13e1ae0) // fmla za.s[x8, 0], { z23.s-z26.s }, z14.s + ld1w { z27.s }, p0/Z, [x20] + KAI_ASM_INST(0x252c4500) // psel p0.b, p1.b/Z, p8.b[w12, #1] + add x20, x20, x1, LSL #2 + ld1w { z28.s }, p0/Z, [x20] + KAI_ASM_INST(0xc13c1b00) // fmla za.s[x8, 0], { z24.s-z27.s }, z12.s + KAI_ASM_INST(0xc1341b20) // fmla za.s[x8, 0], { z25.s-z28.s }, z4.s +KAI_ASM_LABEL(label_12) // Padded: 1 priming loads + mov x12, #0x0 + add x20, x13, x1, LSL #2 + KAI_ASM_INST(0x25244500) // psel p0.b, p1.b/Z, p8.b[w12] + ld1w { z25.s }, p0/Z, [x13] + KAI_ASM_INST(0x252c4500) // psel p0.b, p1.b/Z, 
p8.b[w12, #1] + add x13, x13, x2, LSL #2 + ld1w { z26.s }, p0/Z, [x20] + KAI_ASM_INST(0x25344500) // psel p0.b, p1.b/Z, p8.b[w12, #2] + add x20, x20, x1, LSL #2 + ld1w { z27.s }, p0/Z, [x20] + KAI_ASM_INST(0x253c4500) // psel p0.b, p1.b/Z, p8.b[w12, #3] + add x20, x20, x1, LSL #2 + mov x12, #0x4 + ld1w { z28.s }, p0/Z, [x20] + KAI_ASM_INST(0x25244500) // psel p0.b, p1.b/Z, p8.b[w12] + add x20, x20, x1, LSL #2 + KAI_ASM_INST(0xc13f1b20) // fmla za.s[x8, 0], { z25.s-z28.s }, z15.s + ld1w { z29.s }, p0/Z, [x20] + KAI_ASM_INST(0x252c4500) // psel p0.b, p1.b/Z, p8.b[w12, #1] + add x20, x20, x1, LSL #2 + KAI_ASM_INST(0xc13e1b21) // fmla za.s[x8, 1], { z25.s-z28.s }, z14.s + ld1w { z30.s }, p0/Z, [x20] + KAI_ASM_INST(0xc13d1b40) // fmla za.s[x8, 0], { z26.s-z29.s }, z13.s + KAI_ASM_INST(0xc13c1b41) // fmla za.s[x8, 1], { z26.s-z29.s }, z12.s + KAI_ASM_INST(0xc1351b60) // fmla za.s[x8, 0], { z27.s-z30.s }, z5.s + KAI_ASM_INST(0xc1341b61) // fmla za.s[x8, 1], { z27.s-z30.s }, z4.s +KAI_ASM_LABEL(label_13) // Padded: 0 priming loads + cbz x14, label_16 + mov x12, #0x0 + add x20, x13, x1, LSL #2 + KAI_ASM_INST(0x25244500) // psel p0.b, p1.b/Z, p8.b[w12] + sub x14, x14, #0x1 + sub x11, x11, #0x1 + cmp x14, x11 + ld1w { z20.s }, p0/Z, [x13] + KAI_ASM_INST(0x252c4500) // psel p0.b, p1.b/Z, p8.b[w12, #1] + csel x21, x14, x11, LT + add x13, x13, x2, LSL #2 + sub x11, x11, x21 + ld1w { z21.s }, p0/Z, [x20] + KAI_ASM_INST(0x25344500) // psel p0.b, p1.b/Z, p8.b[w12, #2] + add x20, x20, x1, LSL #2 + ld1w { z22.s }, p0/Z, [x20] + KAI_ASM_INST(0x253c4500) // psel p0.b, p1.b/Z, p8.b[w12, #3] + add x20, x20, x1, LSL #2 + mov x12, #0x4 + ld1w { z23.s }, p0/Z, [x20] + KAI_ASM_INST(0x25244500) // psel p0.b, p1.b/Z, p8.b[w12] + add x20, x20, x1, LSL #2 + ld1w { z24.s }, p0/Z, [x20] + KAI_ASM_INST(0x252c4500) // psel p0.b, p1.b/Z, p8.b[w12, #1] + add x20, x20, x1, LSL #2 + ld1w { z25.s }, p0/Z, [x20] + cbz x21, label_15 +KAI_ASM_LABEL(label_14) // Padded: Main loop + mov x12, #0x0 + 
KAI_ASM_INST(0xc13b1a80) // fmla za.s[x8, 0], { z20.s-z23.s }, z11.s + add x20, x13, x1, LSL #2 + KAI_ASM_INST(0x25244500) // psel p0.b, p1.b/Z, p8.b[w12] + KAI_ASM_INST(0xc13f1a81) // fmla za.s[x8, 1], { z20.s-z23.s }, z15.s + subs x21, x21, #0x1 + KAI_ASM_INST(0xc13e1a82) // fmla za.s[x8, 2], { z20.s-z23.s }, z14.s + ld1w { z20.s }, p0/Z, [x13] + KAI_ASM_INST(0x252c4500) // psel p0.b, p1.b/Z, p8.b[w12, #1] + KAI_ASM_INST(0xc1301aa0) // fmla za.s[x8, 0], { z21.s-z24.s }, z0.s + add x13, x13, x2, LSL #2 + KAI_ASM_INST(0xc13d1aa1) // fmla za.s[x8, 1], { z21.s-z24.s }, z13.s + KAI_ASM_INST(0xc13c1aa2) // fmla za.s[x8, 2], { z21.s-z24.s }, z12.s + ld1w { z21.s }, p0/Z, [x20] + KAI_ASM_INST(0x25344500) // psel p0.b, p1.b/Z, p8.b[w12, #2] + add x20, x20, x1, LSL #2 + KAI_ASM_INST(0xc1371ac0) // fmla za.s[x8, 0], { z22.s-z25.s }, z7.s + KAI_ASM_INST(0xc1351ac1) // fmla za.s[x8, 1], { z22.s-z25.s }, z5.s + KAI_ASM_INST(0xc1341ac2) // fmla za.s[x8, 2], { z22.s-z25.s }, z4.s + ld1w { z22.s }, p0/Z, [x20] + KAI_ASM_INST(0x253c4500) // psel p0.b, p1.b/Z, p8.b[w12, #3] + add x20, x20, x1, LSL #2 + mov x12, #0x4 + KAI_ASM_INST(0xc0060c1c) // mova { z28.d-z31.d }, za.d[x8, #0] + add x8, x8, #0x1 + ld1w { z23.s }, p0/Z, [x20] + KAI_ASM_INST(0x25244500) // psel p0.b, p1.b/Z, p8.b[w12] + add x20, x20, x1, LSL #2 + KAI_ASM_INST(0xc0040e02) // mova za.d[x8, #2], { z16.d-z19.d } + ld1w { z24.s }, p0/Z, [x20] + KAI_ASM_INST(0x252c4500) // psel p0.b, p1.b/Z, p8.b[w12, #1] + add x20, x20, x1, LSL #2 + KAI_ASM_INST(0xc1a9c87c) // fclamp { z28.s-z31.s }, z3.s, z9.s + ld1w { z25.s }, p0/Z, [x20] + st1w { z28.s }, p1, [x9] + add x9, x9, x27, LSL #2 + st1w { z29.s }, p1, [x28] + add x28, x28, x26, LSL #2 + st1w { z30.s }, p1, [x25] + add x25, x25, x23, LSL #2 + st1w { z31.s }, p1, [x24] + add x24, x24, x22, LSL #2 + bgt label_14 +KAI_ASM_LABEL(label_15) // Main loop tail + KAI_ASM_INST(0xc13b1a80) // fmla za.s[x8, 0], { z20.s-z23.s }, z11.s + KAI_ASM_INST(0xc13f1a81) // fmla za.s[x8, 1], { 
z20.s-z23.s }, z15.s + KAI_ASM_INST(0xc13e1a82) // fmla za.s[x8, 2], { z20.s-z23.s }, z14.s + KAI_ASM_INST(0xc1301aa0) // fmla za.s[x8, 0], { z21.s-z24.s }, z0.s + KAI_ASM_INST(0xc13d1aa1) // fmla za.s[x8, 1], { z21.s-z24.s }, z13.s + KAI_ASM_INST(0xc13c1aa2) // fmla za.s[x8, 2], { z21.s-z24.s }, z12.s + KAI_ASM_INST(0xc1371ac0) // fmla za.s[x8, 0], { z22.s-z25.s }, z7.s + KAI_ASM_INST(0xc1351ac1) // fmla za.s[x8, 1], { z22.s-z25.s }, z5.s + KAI_ASM_INST(0xc1341ac2) // fmla za.s[x8, 2], { z22.s-z25.s }, z4.s + KAI_ASM_INST(0xc0060c14) // mova { z20.d-z23.d }, za.d[x8, #0] + add x8, x8, #0x1 + KAI_ASM_INST(0xc0040e02) // mova za.d[x8, #2], { z16.d-z19.d } + KAI_ASM_INST(0xc1a9c874) // fclamp { z20.s-z23.s }, z3.s, z9.s + st1w { z20.s }, p1, [x9] + add x9, x9, x27, LSL #2 + st1w { z21.s }, p1, [x28] + add x28, x28, x26, LSL #2 + st1w { z22.s }, p1, [x25] + add x25, x25, x23, LSL #2 + st1w { z23.s }, p1, [x24] + add x24, x24, x22, LSL #2 +KAI_ASM_LABEL(label_16) // Main loop skip tail + cbz x11, label_18 +KAI_ASM_LABEL(label_17) // Right padding loop + KAI_ASM_INST(0xc0060c04) // mova { z4.d-z7.d }, za.d[x8, #0] + add x8, x8, #0x1 + subs x11, x11, #0x1 + KAI_ASM_INST(0xc0040e02) // mova za.d[x8, #2], { z16.d-z19.d } + KAI_ASM_INST(0xc1a9c864) // fclamp { z4.s-z7.s }, z3.s, z9.s + st1w { z4.s }, p1, [x9] + add x9, x9, x27, LSL #2 + st1w { z5.s }, p1, [x28] + add x28, x28, x26, LSL #2 + st1w { z6.s }, p1, [x25] + add x25, x25, x23, LSL #2 + st1w { z7.s }, p1, [x24] + add x24, x24, x22, LSL #2 + bgt label_17 +KAI_ASM_LABEL(label_18) // End + ldr x20, [x0, #0x18] + incw x15 + whilelt p1.s, x15, x16 + incb x20, ALL, MUL #9 + str x20, [x0, #0x18] + ldr x21, [x0, #0x40] + ldr x20, [x0, #0x0] + add x20, x20, x21, LSL #2 + str x20, [x0, #0x0] + ldr x25, [x0, #0x30] + ldr x24, [x0, #0x48] + ldp x23, x22, [x25, #0x0] + ldp x21, x20, [x24, #0x0] + add x23, x23, x21, LSL #2 + add x22, x22, x20, LSL #2 + stp x23, x22, [x25, #0x0] + ldp x23, x22, [x25, #0x10] + ldp x21, x20, [x24, 
#0x10] + add x23, x23, x21, LSL #2 + add x22, x22, x20, LSL #2 + stp x23, x22, [x25, #0x10] + b.any label_1 + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla) + + KAI_ASM_END diff --git a/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p_interface.h b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..ba860d2605b202377a2a68770bd8fd83c74c0bb7 --- /dev/null +++ b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p_interface.h @@ -0,0 +1,43 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// All micro-kernels variants of the same type share the same interfaces +// In this case, the micro-kernel type is: dw_conv_f32_f32_f32p_planar +// NOTE: +// - get_n_step is not provided as n-step is not relevant in planar kernels. +// - get_lhs_packed_offset is not provided as the lhs is not packed with planar kernels. +// - get_rhs_packed_offset is not provided as rhs offset is not relevant with planar kernels. 
+
+/// Micro-kernel helper functions ("get" methods)
+typedef size_t (*kai_dw_conv_f32_f32_f32p_planar_get_m_step_func_t)(void);
+typedef size_t (*kai_dw_conv_f32_f32_f32p_planar_get_dst_offset_func_t)(size_t out_row_idx, size_t stride_out_row);
+typedef size_t (*kai_dw_conv_f32_f32_f32p_planar_get_dst_size_func_t)(
+    size_t out_height, size_t out_width, size_t num_channels);
+
+/// Micro-kernel core function ("run" method)
+typedef void (*kai_dw_conv_f32_f32_f32p_planar_run_dw_conv_func_t)(
+    const void* inptr, size_t stride_in_row, size_t stride_in_col, unsigned int pad_top, unsigned int pad_left,
+    unsigned int valid_input_rows, const void* weights, const void* bias, void* outptr_start, size_t stride_out_col,
+    size_t stride_out_row, unsigned int valid_out_rows, float act_min, float act_max, float pad_value);
+
+/// Micro-kernel interface
+struct kai_dw_conv_f32_f32_f32p_planar_ukernel {
+    kai_dw_conv_f32_f32_f32p_planar_get_m_step_func_t get_m_step;
+    kai_dw_conv_f32_f32_f32p_planar_get_dst_offset_func_t get_dst_offset;
+    kai_dw_conv_f32_f32_f32p_planar_get_dst_size_func_t get_dst_size;
+    kai_dw_conv_f32_f32_f32p_planar_run_dw_conv_func_t run_dw_conv;
+};
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c b/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c
new file mode 100644
index 0000000000000000000000000000000000000000..b870a8fa1cabb7422b48d26acdd1ab90d85c3478
--- /dev/null
+++ b/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c
@@ -0,0 +1,60 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.h"
+
+#include
+#include
+
+#include "kai/kai_common.h"
+
+// Returns the packed weights size in bytes: num_channels is rounded up to a multiple of
+// kai_get_sme_vector_length_u32() and one float is reserved per channel for every filter tap.
+size_t kai_rhs_get_dst_size_dw_conv_pack_x32p1vl_x32_sme(
+    size_t filter_height, size_t filter_width, size_t num_channels) {
+    const size_t depth_elements = kai_roundup(num_channels, kai_get_sme_vector_length_u32());
+    return depth_elements * filter_height * filter_width * sizeof(float);
+}
+
+// Repacks the weights into groups of vl channels (vl = kai_get_sme_vector_length_u32()).
+//
+// The source is read as [tap][channel] with filter_height * filter_width taps. The destination is
+// written as [channel group][tap][lane], each tap occupying vl floats. Lanes past num_channels in
+// the last group are zero-filled. height and width are not used.
+void kai_run_rhs_dw_conv_pack_x32p1vl_x32_sme(
+    void* rhs, void* rhs_packed, size_t filter_height, size_t filter_width, size_t height, size_t width,
+    size_t num_channels) {
+    KAI_ASSUME(rhs != NULL);
+    KAI_ASSUME(rhs_packed != NULL);
+    KAI_UNUSED(height);
+    KAI_UNUSED(width);
+
+    // Byte-wise views of the buffers; the source is only read.
+    const uint8_t* src = (const uint8_t*)(rhs);
+    uint8_t* dst = (uint8_t*)(rhs_packed);
+
+    const size_t vl = kai_get_sme_vector_length_u32();
+    const size_t element_size = sizeof(float);
+    const size_t num_taps = filter_height * filter_width;
+    const size_t slot_bytes = vl * element_size;
+
+    for (size_t n = 0; n < num_channels; n += vl) {
+        // Channels present in this group; smaller than vl only for the last group.
+        const size_t count = (vl < (num_channels - n)) ? vl : (num_channels - n);
+        const size_t copy_bytes = count * element_size;
+
+        // Copy each of the weights in turn
+        for (size_t idx = 0; idx < num_taps; idx++) {
+            const uint8_t* src_ptr = src + ((idx * num_channels + n) * element_size);
+            memcpy(dst, src_ptr, copy_bytes);
+            if (copy_bytes < slot_bytes) {
+                // Zero the unused lanes so every byte of the packed buffer is initialised.
+                memset(dst + copy_bytes, 0, slot_bytes - copy_bytes);
+            }
+            dst += slot_bytes;  // move ptr.
+        }
+    }
+}
diff --git a/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.h b/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.h
new file mode 100644
index 0000000000000000000000000000000000000000..fa3a04ea398ea27921b0e6e2cbeda0f19dc9a241
--- /dev/null
+++ b/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.h
@@ -0,0 +1,43 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+/// Get the size in bytes of the packed data buffer.
+/// @param[in] filter_height filter height of filter being used in convolution.
+/// @param[in] filter_width filter width of filter being used in convolution.
+/// @param[in] num_channels Number of channels in input matrix.
+/// @return The size in bytes of packed data buffer.
+size_t kai_rhs_get_dst_size_dw_conv_pack_x32p1vl_x32_sme(
+    size_t filter_height, size_t filter_width, size_t num_channels);
+
+/// Runs the packing function for the depthwise convolution kernel.
+///
+/// NOTE: filter_height/filter_width are separate from height/width of weights to allow for padding when using
+/// weights shapes different to kernel conv filter size. These should be the same in typical use cases.
+///
+/// NOTE: The API below is experimental and may change in the future, particularly the parameters "height" and "width"
+///
+/// @param[in] rhs Rhs matrix data buffer
+/// @param[out] rhs_packed Packed data matrix buffer
+/// @param[in] filter_height filter height of filter being used in convolution.
+/// @param[in] filter_width filter width of filter being used in convolution.
+/// @param[in] height Height dimension of rhs matrix. Unused. (Typically equivalent to filter_height)
+/// @param[in] width Width dimension of rhs matrix. Unused. (Typically equivalent to filter_width)
+/// @param[in] num_channels Number of channels in input matrix.
+void kai_run_rhs_dw_conv_pack_x32p1vl_x32_sme(
+    void* rhs, void* rhs_packed, size_t filter_height, size_t filter_width, size_t height, size_t width,
+    size_t num_channels);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
diff --git a/test/reference/depthwise_conv.cpp b/test/reference/depthwise_conv.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..845b9face57f0dfb8f45b438d557a6f74ef8c9f3
--- /dev/null
+++ b/test/reference/depthwise_conv.cpp
@@ -0,0 +1,86 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "test/reference/depthwise_conv.hpp"
+
+#include
+
+#include "kai/kai_common.h"
+
+namespace kai::test {
+
+// Writes the padding as a name fragment. The values are emitted in left, right, bottom, top order.
+void PrintTo(const Padding2D& pad, std::ostream* os) {
+    *os << "__PAD_" << pad.left << "_" << pad.right << "_" << pad.bottom << "_" << pad.top << "_";
+}
+
+// Reference depthwise convolution with unit stride.
+//
+// The feature map is read as [batch][row][col][channel], the weights as [filter row][filter col][channel]
+// and the bias as one value per channel. Filter taps that land outside the input are skipped, so padded
+// positions contribute nothing. Each output accumulates in a float, has the bias added, and is written
+// as [batch][out row][out col][channel]. No clamping is performed.
+template
+Buffer depthwise_reference(
+    const size_t batches, const size_t in_height, const size_t in_width, const size_t channels,
+    const size_t filter_height, const size_t filter_width, const void* feature_map, const void* weights,
+    const void* bias, const Padding2D& pad) {
+    // Output dims of a unit-stride convolution over the padded input. The arithmetic is unsigned, so the
+    // padded input must be at least as large as the filter in both dimensions.
+    const size_t out_height = (in_height + pad.top + pad.bottom + 1 - filter_height);
+    const size_t out_width = in_width + pad.left + pad.right + 1 - filter_width;
+    const size_t out_size = out_height * out_width * batches * channels;
+
+    // Destination buffer: out_size elements, converted from bits to bytes.
+    Buffer dst(out_size * size_in_bits / 8);
+
+    for (size_t b = 0; b < batches; ++b) {
+        for (size_t out_h = 0; out_h < out_height; ++out_h) {
+            for (size_t out_w = 0; out_w < out_width; ++out_w) {
+                const size_t out_base = ((b * out_height + out_h) * out_width + out_w) * channels;
+
+                // Apply filter to feature map.
+                for (size_t ic = 0; ic < channels; ++ic) {
+                    float sum = 0.0f;
+
+                    for (size_t kernel_h = 0; kernel_h < filter_height; ++kernel_h) {
+                        // Skip this filter row if it falls outside the input height: it is padding.
+                        const int in_y = static_cast(out_h + kernel_h) - static_cast(pad.top);
+                        if (in_y < 0 || in_height <= static_cast(in_y)) continue;
+
+                        for (size_t kernel_w = 0; kernel_w < filter_width; ++kernel_w) {
+                            // Skip this filter column if it falls outside the input width: it is padding.
+                            const int in_x = static_cast(out_w + kernel_w) - static_cast(pad.left);
+                            if (in_x < 0 || in_width <= static_cast(in_x)) continue;
+
+                            auto in_idx = ((b * in_height + in_y) * in_width + in_x) * channels + ic;
+                            auto weights_idx = ((kernel_h * filter_width) + kernel_w) * channels + ic;
+
+                            auto wei_value = read_array(weights, weights_idx);
+                            auto in_value = read_array(feature_map, in_idx);
+
+                            // Accumulate the product for this filter tap.
+                            sum += in_value * wei_value;
+                        }
+                    }
+
+                    auto out_idx = out_base + ic;
+                    sum = sum + (T)read_array(bias, ic);
+                    write_array(dst.data(), out_idx, sum);
+                }
+            }
+        }
+    }
+    return dst;
+}
+
+// Explicit template
+template Buffer depthwise_reference(
+    const size_t batches, const size_t in_height, const size_t in_width, const size_t channels,
+    const size_t filter_height, const size_t filter_width, const void* feature_map, const void* weights,
+    const void* bias, const Padding2D& pad);
+
+}  // namespace kai::test
diff --git a/test/reference/depthwise_conv.hpp b/test/reference/depthwise_conv.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..edbd9b16ce5bb55d2c118774d681751cabd86553
--- /dev/null
+++ b/test/reference/depthwise_conv.hpp
@@ -0,0 +1,66 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+#include
+
+#include "test/common/buffer.hpp"
+#include "test/common/data_type.hpp"
+#include "test/common/memory.hpp"
+
+namespace kai::test {
+
+struct Padding2D {  // Padding on the four sides; every field takes part in Hash and operator==.
+    size_t left;
+    size_t right;
+    size_t top;
+    size_t bottom;
+
+    struct Hash {
+        size_t operator()(const Padding2D pad) const {
+            return                               // XOR of the per-field hashes, each shifted by its own amount
+                (std::hash{}(pad.left) << 0) ^   //
+                (std::hash{}(pad.right) << 1) ^  //
+                (std::hash{}(pad.top) << 2) ^    //
+                (std::hash{}(pad.bottom) << 3);
+        }
+    };
+
+private:
+    friend bool operator==(const Padding2D& lhs, const Padding2D& rhs) {
+        return  // field-wise equality
+            lhs.left == rhs.left && lhs.right == rhs.right && lhs.top == rhs.top && lhs.bottom == rhs.bottom;
+    }
+    friend std::ostream& operator<<(std::ostream& os, const Padding2D& shape);  // NOTE(review): declaration only; no definition is visible in this part of the diff - confirm one exists
+};
+
+void PrintTo(const Padding2D& pad, std::ostream* os);
+
+/// Reference depthwise convolution.
+///
+/// @tparam T Data type.
+///
+/// @param[in] batches Batch dimension of feature map.
+/// @param[in] in_height height of feature map.
+/// @param[in] in_width width of feature map.
+/// @param[in] channels Number of channels in feature map.
+/// @param[in] filter_height Height dimension in filter.
+/// @param[in] filter_width Width of convolution filter.
+/// @param[in] feature_map Ptr to start of feature map.
+/// @param[in] weights Ptr to start of weights buffer/tensor.
+/// @param[in] bias Ptr to start of bias buffer.
+/// @param[in] pad Padding object.
+///
+/// @note There are no clamp parameters: the declaration takes only the arguments listed above.
+///
+/// @return The result data buffer.
+template
+Buffer depthwise_reference(
+    const size_t batches, const size_t in_height, const size_t in_width, const size_t channels,
+    const size_t filter_height, const size_t filter_width, const void* feature_map, const void* weights,
+    const void* bias, const Padding2D& pad);
+
+}  // namespace kai::test
diff --git a/test/tests/depthwise_planar_test.cpp b/test/tests/depthwise_planar_test.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..3164e51d47bbf379badfded903b36e182836d061
--- /dev/null
+++ b/test/tests/depthwise_planar_test.cpp
@@ -0,0 +1,304 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h"
+#include "kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p_interface.h"
+#include "kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.h"
+#include "test/common/buffer.hpp"
+#include "test/common/compare.hpp"
+#include "test/common/cpu_info.hpp"
+#include "test/common/matmul_test_common.hpp"
+#include "test/common/matrix_portion.hpp"
+#include "test/common/memory.hpp"
+#include "test/common/round.hpp"
+#include "test/common/sme.hpp"
+#include "test/reference/clamp.hpp"
+#include "test/reference/depthwise_conv.hpp"
+#include "test/reference/fill.hpp"
+
+namespace kai::test {
+
+namespace {
+
+/// Callables of the depthwise kernel under test: destination size and offset queries, the row step, and the convolution.
+struct DepthwisePlanarKernel {
+    std::function get_dst_size;    // called with (output height, output width, channels)
+    std::function get_dst_offset;
+    std::function get_m_step;      // number of output rows the test advances by per conv call
+    std::function
+        conv;
+};
+
+// Callables of the RHS packing kernel: packed-size query and the packing routine.
+struct RhsPackDepthwiseKernel {
+    std::function get_rhs_packed_size;  // called with (filter height, filter width, channels)
+    std::function
+        pack;
+};
+
+/// Description of one depthwise kernel under test: name, CPU support check, filter size, data types and callables.
+struct Depthwise {
+    std::string_view name;
+    std::function is_supported;
+    std::pair filter;  // (filter height, filter width)
+    DataType data_type;
+    DataType acc_type;
+    RhsPackDepthwiseKernel rhs;
+    DepthwisePlanarKernel depthwise;
+};
+
+/// Convenience types for testing.
+using DepthwiseArray = std::array;
+using DepthwiseParamsParams = std::tuple;
+using DepthwisePlanarTest = testing::TestWithParam;
+
+/// Returns the interface struct of the SME2 3x3 planar kernel; the function-local static is refilled with the same four function pointers on every call.
+const kai_dw_conv_f32_f32_f32p_planar_ukernel& get_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla() {
+    static kai_dw_conv_f32_f32_f32p_planar_ukernel ukernel;
+    ukernel.get_m_step = kai_get_m_step_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla;
+    ukernel.get_dst_offset = kai_get_dst_offset_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla;
+    ukernel.get_dst_size = kai_get_dst_size_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla;
+    ukernel.run_dw_conv = kai_run_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla;
+    return ukernel;
+}
+
+const DepthwiseArray& get_depthwise_methods() {
+    // FP32 kernels with 3x3 filter. Entry 0 of the function-local static is (re)assigned on every call.
+    static DepthwiseArray depthwise_methods{};
+    depthwise_methods[0].name = "kai_depthwise_planar_f32p4_3x3_s1_4_sme2_mla";
+    depthwise_methods[0].rhs.get_rhs_packed_size = kai_rhs_get_dst_size_dw_conv_pack_x32p1vl_x32_sme;
+    depthwise_methods[0].rhs.pack = kai_run_rhs_dw_conv_pack_x32p1vl_x32_sme;
+    depthwise_methods[0].is_supported = cpu_has_sme2;
+    depthwise_methods[0].filter = {3, 3};
+
+    const kai_dw_conv_f32_f32_f32p_planar_ukernel& ukernel_f32 =
+        get_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla();
+    depthwise_methods[0].data_type = DataType::FP32;
+    depthwise_methods[0].acc_type = DataType::FP32;
+    depthwise_methods[0].depthwise.get_m_step = ukernel_f32.get_m_step;
+    depthwise_methods[0].depthwise.get_dst_size = ukernel_f32.get_dst_size;
+    depthwise_methods[0].depthwise.get_dst_offset = ukernel_f32.get_dst_offset;
+    depthwise_methods[0].depthwise.conv = ukernel_f32.run_dw_conv;
+    return depthwise_methods;
+}
+
+/// Test reference identification. NOTE(review): Hash and operator== only look at in_shape, pad and clamp_rate; wei_shape, out_shape, dt and dt_acc are ignored, so ids differing only in those fields compare equal.
+struct TestDataId {
+    MatMulShape in_shape;
+    MatMulShape wei_shape;
+    MatMulShape out_shape;
+    Padding2D pad;
+    DataType dt;
+    DataType dt_acc;
+    float clamp_rate;
+
+    struct Hash {
+        size_t operator()(const TestDataId& test_id) const {
+            return                                              // XOR of three shifted hashes
+                (MatMulShape::Hash{}(test_id.in_shape) << 0) ^  //
+                (Padding2D::Hash{}(test_id.pad) << 1) ^         //
+                (std::hash{}(test_id.clamp_rate) << 2);         //
+        }
+    };
+
+private:
+    friend bool operator==(const TestDataId& lhs, const TestDataId& rhs) {
+        return                                 // compares the same three fields that Hash uses
+            lhs.in_shape == rhs.in_shape &&    //
+            lhs.pad == rhs.pad &&              //
+            lhs.clamp_rate == rhs.clamp_rate;  //
+    }
+};
+
+/// Test reference data
+struct TestData {
+    Buffer lhs;         ///< LHS input matrix
+    Buffer rhs;         ///< RHS input matrix
+    Buffer bias;        ///< Bias vector
+    Buffer out;         ///< Reference depthwise result
+    Buffer padding;     ///< Padding buffer. NOTE(review): not assigned by generate_reference in this file.
+    Range clamp_range;  ///< Clamp range
+};
+
+/// Generates reference data on first request and caches it for later requests with an equal id.
+struct ReferenceGenerator {
+    /// Retrieve reference data for the provided test identification; results are cached in a function-local static map.
+    static const TestData& get_test_reference(const TestDataId test_id) {
+        static std::unordered_map m_data;
+        if (const auto itr = m_data.find(test_id); itr != end(m_data)) {
+            return itr->second;
+        }
+
+        return m_data[test_id] = generate_reference(test_id);
+    }
+
+private:
+    /// Returns the current seed value and then increments it, so successive calls yield 0, 1, 2, ...
+    static size_t get_seed() {
+        static size_t seed = 0;
+        return seed++;
+    }
+
+    /// Generate reference data: random inputs, the reference convolution output, and that output clamped to a computed range.
+    static TestData generate_reference(const TestDataId& test_id) {
+        const auto& [in_shape, wei_shape, out_shape, pad, dt, acc_dt, clamp_rate] = test_id;
+
+        // Generate random input data; each buffer draws a fresh seed.
+        Buffer lhs = fill_matrix_random(in_shape.m, in_shape.n * in_shape.k, DataFormat(dt), get_seed());
+        Buffer rhs = fill_matrix_random(wei_shape.m, wei_shape.n * wei_shape.k, DataFormat(dt), get_seed());
+        Buffer bias = fill_matrix_random(1, out_shape.k, DataFormat(dt), get_seed());
+
+        // Call reference function with a batch count of 1.
+        Buffer out = depthwise_reference(
+            1, in_shape.m, in_shape.n, in_shape.k, wei_shape.m, wei_shape.n, lhs.data(), rhs.data(), bias.data(), pad);
+
+        const auto [min, max] =
+            find_clamp_range(dt, out.data(), out_shape.m * out_shape.n * out_shape.k, 1.0F - clamp_rate);
+        Buffer out_clamped = clamp(dt, out.data(), out_shape.m * out_shape.n * out_shape.k, min, max);
+
+        // Populate reference data; the clamped output is what gets stored as `out`.
+        TestData test_reference;
+        test_reference.lhs = std::move(lhs);
+        test_reference.rhs = std::move(rhs);
+        test_reference.bias = std::move(bias);
+        test_reference.out = std::move(out_clamped);
+        test_reference.clamp_range = {min, max};
+        return test_reference;
+    };
+};
+
+/// Perform RHS packing for depthwise: shape.m x shape.n is the filter size and shape.k the channel count.
+Buffer pack_rhs(const RhsPackDepthwiseKernel& kernel, const MatMulShape& shape, const TestData& reference) {
+    // Calculate size, and allocate buffer
+    const size_t dst_size = kernel.get_rhs_packed_size(shape.m, shape.n, shape.k);
+    Buffer dst(dst_size);
+
+    // RHS Pack API is subject to change. The filter size is passed for both the filter and the weights dimensions.
+    kernel.pack(reference.rhs.data(), dst.data(), shape.m, shape.n, shape.m, shape.n, shape.k);
+    return dst;
+}
+
+Buffer dw_conv(  // Runs the kernel over the output rows of `portion`, get_m_step() rows per call, and returns the output buffer.
+    const DepthwisePlanarKernel& kernel, const Rect& portion, const MatMulShape& in_shape, const MatMulShape& out_shape,
+    const Padding2D pad, const TestData& reference, const Buffer& rhs_packed, Range clamp_range, DataType type) {
+    KAI_UNUSED(type);  // NOTE(review): type is in fact used a few lines below
+
+    const size_t dst_size = kernel.get_dst_size(out_shape.m, out_shape.n, out_shape.k);
+    Buffer dst(dst_size);
+
+    const size_t dt_size_bytes = data_type_size_in_bits(type) / 8;
+    const size_t stride_in_row = in_shape.n * in_shape.k * dt_size_bytes;
+    const size_t stride_out_row = out_shape.n * out_shape.k * dt_size_bytes;
+    const size_t stride_col = dt_size_bytes;
+
+    // Each iteration handles one group of get_m_step() output rows.
+    for (size_t out_row = portion.start_row(); out_row < portion.end_row(); out_row += kernel.get_m_step()) {
+        const int start_in_row = out_row - pad.top;  // unsigned subtraction narrowed to int; negative means rows of top padding
+        const size_t pad_top = (start_in_row < 0) ? (-start_in_row) : 0;
+        const size_t in_row = (start_in_row < 0) ? 0 : start_in_row;
+
+        const size_t valid_input_rows = (in_row < in_shape.m) ? (in_shape.m - in_row) : 0;
+        const size_t valid_out_rows = (out_shape.m - out_row);
+
+        kernel.conv(
+            reference.lhs.data() + (in_row * stride_in_row), stride_in_row, stride_col, pad_top, pad.left,
+            valid_input_rows, rhs_packed.data(), reference.bias.data(), dst.data() + (out_row * stride_out_row),
+            stride_col, stride_out_row, valid_out_rows, clamp_range.min, clamp_range.max, 0.f);
+    }
+
+    return dst;
+}
+}  // namespace
+
+/// End-to-end test for depthwise kernels: pack the weights, run the kernel, compare against the cached reference.
+TEST_P(DepthwisePlanarTest, Output) {
+    const auto& [method, in_shape, padding, out_portion, clamp_rate] = GetParam();
+    if (not method.is_supported()) {
+        GTEST_SKIP() << "Unsupported CPU feature";
+    }
+
+    // Calculate Shapes. NOTE(review): the unsigned expressions are narrowed to int; the <= 0 check relies on that conversion.
+    int out_height = (in_shape.m + padding.top + padding.bottom + 1 - method.filter.first);
+    int out_width = (in_shape.n + padding.left + padding.right + 1 - method.filter.second);
+    if (out_height <= 0 || out_width <= 0)  // Check shapes valid
+    {
+        GTEST_SKIP() << "Invalid DCONV Shape";
+    }
+
+    const size_t dt_size_bytes = data_type_size_in_bits(method.data_type) / 8;
+    MatMulShape wei_shape = {method.filter.first, method.filter.second, in_shape.k};
+    MatMulShape out_shape = {static_cast(out_height), static_cast(out_width), (in_shape.k)};
+
+    // 1. Calculate reference.
+    const TestData& test_data = ReferenceGenerator::get_test_reference(
+        {in_shape, wei_shape, out_shape, padding, method.data_type, method.acc_type, clamp_rate});
+
+    // 2. Run DW Conv kernels and compare.
+    Buffer rhs_packed = pack_rhs(method.rhs, wei_shape, test_data);
+    const Rect portion = out_portion.compute_portion(
+        out_shape.m, out_shape.n * out_shape.k, method.depthwise.get_m_step(), (rhs_packed.size() / dt_size_bytes));
+
+    Buffer out = dw_conv(
+        method.depthwise, portion, in_shape, out_shape, padding, test_data, rhs_packed, test_data.clamp_range,
+        method.data_type);
+
+    DefaultMismatchHandler handler(0, 0.0001, 0, 0.001);
+    const auto success = compare(
+        out.data(), test_data.out.data(), DataType::FP32, out_shape.m, out_shape.n * out_shape.k, portion, handler);
+    ASSERT_TRUE(success);
+}
+
+/// Name generator for test case: method name, shape, padding, clamp rate as a percentage, then the portion.
+[[maybe_unused]] static void PrintTo(const DepthwiseParamsParams& param, std::ostream* os) {
+    const auto& [method, shape, padding, portion, clamp_rate] = param;
+    *os << method.name << "__";
+    PrintTo(shape, os);
+    PrintTo(padding, os);
+    *os << "__clamp_rate_" << static_cast(clamp_rate * 100) << "__";
+    PrintTo(portion, os);
+}
+
+/// Test parameter listing: every combination of method, input shape, padding, output portion and clamp rate.
+INSTANTIATE_TEST_SUITE_P(
+    Depthwise, DepthwisePlanarTest,
+    testing::Combine(
+        testing::ValuesIn(get_depthwise_methods()),  //
+        testing::ValuesIn({
+            // clang-format off
+            // IN_HEIGHT, IN_WIDTH, IN_CHANNELS
+            MatMulShape{ 4, 4, 1}, //
+            MatMulShape{ 8, 4, 16}, //
+            MatMulShape{ 96, 33, 37}, //
+            MatMulShape{ 99, 22, 51}, //
+            MatMulShape{ 127, 127, 127}, //
+            // clang-format on
+        }),
+        testing::ValuesIn(
+            {Padding2D{0, 0, 0, 0}, Padding2D{0, 1, 0, 1}, Padding2D{1, 1, 1, 1}, Padding2D{5, 11, 7, 3}}),  //
+        testing::ValuesIn({
+            // clang-format off
+            // (Start row , start col , height , width)
+            MatrixPortion( 0 , 0 , 1 , 1 ), // Full matrix.
+            // clang-format on
+        }),
+        testing::ValuesIn(std::initializer_list{1.f})),  //
+    testing::PrintToStringParamName());
+
+}  // namespace kai::test