From 940376ad271eb52c27337ffd940012ba6a40e1d9 Mon Sep 17 00:00:00 2001 From: Mohammed Suhail Munshi Date: Thu, 26 Jun 2025 02:08:30 +0100 Subject: [PATCH 1/6] Add SME Depthwise Convolution Planar Kernels and Reference implementation. - Add example use of planar kernels with padding. - Planar kernels produce four output rows given an input pointer with stride and padding arguments - Adds packing function for weights packing with depthwise kernels. - Packing function packs based on vector length along channel dim. Signed-off-by: Mohammed Suhail Munshi --- CMakeLists.txt | 4 +- .../CMakeLists.txt | 39 ++ .../dconv.cpp | 339 ++++++++++++ ...2_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c | 513 ++++++++++++++++++ ...2_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h | 58 ++ .../kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c | 47 ++ .../kai_rhs_dw_conv_pack_x32p1vl_x32_sme.h | 43 ++ test/reference/depthwise_conv.cpp | 76 +++ test/reference/depthwise_conv.hpp | 39 ++ 9 files changed, 1157 insertions(+), 1 deletion(-) create mode 100644 examples/dw_conv_f32_f32_f32p_planar_sme2/CMakeLists.txt create mode 100644 examples/dw_conv_f32_f32_f32p_planar_sme2/dconv.cpp create mode 100644 kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c create mode 100644 kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h create mode 100644 kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c create mode 100644 kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.h create mode 100644 test/reference/depthwise_conv.cpp create mode 100644 test/reference/depthwise_conv.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 6fe93eff..23d53f19 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -319,6 +319,7 @@ set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c + kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c ) @@ -432,11 +433,12 @@ if(KLEIDIAI_BUILD_TESTS) test/reference/reduce.cpp test/reference/reorder.cpp test/reference/transpose.cpp + test/reference/depthwise_conv.cpp ) target_compile_options(kleidiai_test_framework PUBLIC ${KLEIDIAI_WARNING_FLAGS} - PUBLIC $<$>:-march=armv8-a${KLEIDIAI_INTERNAL_EXTRA_ARCH}> + PUBLIC $<$>:-march=armv8-a> ) if(MSVC) diff --git a/examples/dw_conv_f32_f32_f32p_planar_sme2/CMakeLists.txt b/examples/dw_conv_f32_f32_f32p_planar_sme2/CMakeLists.txt new file mode 100644 index 00000000..d271b86a --- /dev/null +++ b/examples/dw_conv_f32_f32_f32p_planar_sme2/CMakeLists.txt @@ -0,0 +1,39 @@ +# +# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +cmake_minimum_required(VERSION 3.16) + +project(main) + +set(CMAKE_CXX_STANDARD 17) +set(KAI_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../) +set(KAI_BUILD ${KAI_PATH}/build) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +# Create a tiny interface target that only carries the desired ISA flags +add_library(sve_flags INTERFACE) +target_compile_options(sve_flags INTERFACE + 
-march=armv8.2-a+sve+sve2 + -fno-tree-vectorize) + +include_directories(${KAI_PATH}) + +# Files requires to build the executable +add_executable( + main dconv.cpp + "${KAI_PATH}/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c" + "${KAI_PATH}/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c" + ) + +target_compile_options(main + PRIVATE "-march=armv8.2-a+sve+sve2;-fno-tree-vectorize" +) + +target_link_libraries(main PRIVATE sve_flags) + +target_compile_definitions(main + PRIVATE $<$:KAI_DEBUG> +) diff --git a/examples/dw_conv_f32_f32_f32p_planar_sme2/dconv.cpp b/examples/dw_conv_f32_f32_f32p_planar_sme2/dconv.cpp new file mode 100644 index 00000000..61d438f8 --- /dev/null +++ b/examples/dw_conv_f32_f32_f32p_planar_sme2/dconv.cpp @@ -0,0 +1,339 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "kai/kai_common.h" +#include "kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h" +#include "kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.h" + +using VEC_TYPE = std::vector; + +namespace { +constexpr float clamp_min = std::numeric_limits::lowest(); +constexpr float clamp_max = std::numeric_limits::max(); + +struct Padding2D { + size_t left = 0; + size_t right = 0; + size_t bottom = 0; + size_t top = 0; +}; + +struct Shape { + size_t n = 1; + size_t h = 1; + size_t w = 1; + size_t c = 1; + + [[nodiscard]] auto size() const -> size_t { + return n * h * w * c; + } + + friend std::ostream& operator<<(std::ostream& os, const Shape& shape) { + os << " [ " << shape.n << " , " << shape.h << " ," << shape.w << " , " << shape.c << " ] "; + return os; + } + + constexpr const std::size_t& operator[](std::size_t idx) const { + switch (idx) { + case 0: + return n; + case 1: + return h; + case 2: + return w; + case 3: + return c; + default: + throw std::out_of_range("Shape-index out of range (0-3)"); + } + } +}; + +#ifdef KAI_DEBUG +void print_tensor(const Shape& shape, const char* name, const float* src) { + std::cout << "\n\n" << name << " = [\n"; + for (size_t n = 0; n < shape.n; n++) { + std::cout << "\n"; + for (size_t y = 0; y < shape.h; ++y) { + std::cout << " ["; + for (size_t x = 0; x < shape.w; x++) { + std::cout << "["; + for (size_t c = 0; c < shape.c; c++) { + if (c != 0) std::cout << " , "; + std::cout << std::setprecision(3) << std::fixed + << src[n * shape.h * shape.w * shape.c + y * shape.w * shape.c + x * shape.c + c]; + } + std::cout << "] "; + } + std::cout << ("],\n"); + } + } + std::cout << ("]\n\n"); +} + +void print_raw(const Shape& shape, const char* name, const VEC_TYPE& src) { + std::cout << "\n\n" << name << " = ["; + for (size_t i = 0; i < shape.size(); i++) { + if (i != 0) std::cout << " , "; + std::cout << std::setprecision(1) << std::fixed << (float)src[i]; + } + std::cout << "]\n"; +} + +#endif +/// Fills the matrix with incremental values according to the provided weight. +/// @param[in] size Total number of elements to fill in passed vector;. +/// @param[in] dst Vector representing a tensor to fill. +/// @param[in] weight A weight value to increment by. 
+void fill_matrix(size_t size, VEC_TYPE& dst, const float weight) { + for (size_t i = 0; i < size; i++) { + dst[i] = float((10 * i) * weight); + } +} + +void fill_matrix_uniform(size_t size, VEC_TYPE& dst, const float weight) { + for (size_t i = 0; i < size; i++) { + dst[i] = float(weight); + } +} + +/// Depthwise Convolution - Expects NHWC dataformat. Padding value is 0. +/// +/// @tparam T Data type. +/// +/// @param[in] batches Batch dimension of feature map. +/// @param[in] in_height height of feature map. +/// @param[in] in_width width of feature map. +/// @param[in] channels Number of channels in feature map. +/// @param[in] filter_height Height dimension in filter. +/// @param[in] filter_width Width of convolution filter. +/// @param[in] feature_map Ptr to start of feature map. +/// @param[in] weights Ptr to start of weights buffer/tensor. +/// @param[in] bias Ptr to start of bias buffer. +/// @param[in] clamp_min float value to clamp output to (lower bound). +/// @param[in] clamp_max float value to clamp output to (upper bound). +/// +/// @return The result data buffer. +template +void depthwise_reference( + const size_t batches, const size_t in_height, const size_t in_width, const size_t channels, + const size_t filter_height, const size_t filter_width, const void* feature_map, const void* weights, + const void* bias, void* out, float clamp_min, float clamp_max, const Padding2D pad) { + // Calculate output dims (Padding = Valid). + const size_t out_height = (in_height + pad.top + pad.bottom + 1 - filter_height); + const size_t out_width = in_width + pad.left + pad.right + 1 - filter_width; + const size_t out_size = out_height * out_width * batches * channels; + + // We accumulate in FP32 and clamp and cast to return type later. + std::vector acc(out_size, 0.0f); + + for (size_t b = 0; b < batches; ++b) { + for (size_t out_h = 0; out_h < out_height; ++out_h) { + for (size_t out_w = 0; out_w < out_width; ++out_w) { + const size_t out_base = ((b * out_height + out_h) * out_width + out_w) * channels; + + // Apply filter to feature map. + for (size_t ic = 0; ic < channels; ++ic) { + float sum = 0.0f; + + for (size_t kernel_h = 0; kernel_h < filter_height; ++kernel_h) { + // Determine if input height bounds. If not, then this is padding. + const int in_y = static_cast(out_h + kernel_h) - static_cast(pad.top); + if (in_y < 0 || in_height <= static_cast(in_y)) continue; + + for (size_t kernel_w = 0; kernel_w < filter_width; ++kernel_w) { + // Determine if in input width bounds, if not this is padding. + const int in_x = static_cast(out_w + kernel_w) - static_cast(pad.left); + if (in_x < 0 || in_width <= static_cast(in_x)) continue; + + auto in_idx = ((b * in_height + in_y) * in_width + in_x) * channels + ic; + auto weights_idx = ((kernel_h * filter_width) + kernel_w) * channels + ic; + + auto wei_value = reinterpret_cast(weights)[weights_idx]; + auto in_value = reinterpret_cast(feature_map)[in_idx]; + + // Perform actual accumulation and store in output vector + sum += in_value * wei_value; + } + } + + auto out_idx = out_base + ic; + float bias_value = reinterpret_cast(bias)[ic]; + sum = sum + bias_value; + sum = std::clamp(sum, clamp_min, clamp_max); + reinterpret_cast(out)[out_idx] = sum; + } + } + } + } +} +} // namespace + +int main() { + // FIXED : We should not care about batches for this example, those are just loops. + // All other variables below are kernel specific. 
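+    // Worked example of the shapes exercised below (illustrative): with a 3x3 filter and
+    // SAME padding, pad_total = 2 in each dimension, split as top = left = 1 and
+    // bottom = right = 1, so a 141-row input keeps 141 output rows
+    // (141 + 1 + 1 + 1 - 3); with VALID padding the output shrinks to 139 rows.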
+ const int batches = 1; + enum class pad_mode { SAME, VALID }; + + size_t total_test = 0; + for (pad_mode pad : {pad_mode::SAME, pad_mode::VALID}) { + for (size_t width = 128; width < 129; width += 2) { + for (size_t height = 141; height < 142; height += 2) { + for (size_t channels = 1; channels < 64; channels += 7) { + total_test++; + const int filter_height = 3; + const int filter_width = 3; + const int depth_multiplier = 1; // Only dm =1 supported. + + assert(filter_height > 1 && filter_width > 1); + + const size_t pad_total_height = (pad == pad_mode::SAME) ? filter_height - 1 : 0; + const size_t pad_total_width = (pad == pad_mode::SAME) ? filter_width - 1 : 0; + Padding2D padding; + padding.top = pad_total_height / 2; + padding.left = pad_total_width / 2; + padding.right = pad_total_width - padding.left; + padding.bottom = pad_total_height - padding.top; + + Shape in_shape{batches, height, width, channels}; + Shape wei_shape{filter_height, filter_width, channels, depth_multiplier}; + Shape bias_shape{depth_multiplier * channels}; + Shape out_shape{ + batches, (height + padding.top + padding.bottom + 1 - filter_height), + (width + padding.left + padding.right + 1 - filter_width), channels * depth_multiplier}; + + VEC_TYPE input(in_shape.size(), 0.0f); + VEC_TYPE weights(wei_shape.size(), 0.1f); + VEC_TYPE bias(bias_shape.size(), 0.0f); + VEC_TYPE out(out_shape.size(), 0.0f); + VEC_TYPE ref(out_shape.size(), 0.0f); + + fill_matrix(in_shape.size(), input, 0.01f); + fill_matrix(wei_shape.size(), weights, 0.02f); + fill_matrix_uniform(bias_shape.size(), bias, 0.f); + + // For testing using Python. +#ifdef KAI_DEBUG + { + std::cout << "\n#BEGIN PARAMS\n"; + std::cout << "\nbatch, height, width, channels = " << batches << ", " << height << ", " << width + << ", " << channels << std::endl; + std::cout << "\nfilter_height, filter_width = " << filter_height << ", " << filter_width + << std::endl; + print_raw(in_shape, "Inputs ", input); + print_raw(wei_shape, "Weights ", weights); + print_raw(bias_shape, "Bias ", bias); + std::cout << "\npad_top, pad_bottom = " << padding.top << ", " << padding.bottom << std::endl; + std::cout << "\npad_left, pad_right = " << padding.left << ", " << padding.right << std::endl + << std::endl; + std::cout << "\n#END PARAMS\n"; + } +#endif // KAI_DEBUG + + // ------------------------------------------------- + // 1. Calculate Reference Depthwise Values. + // ------------------------------------------------- + depthwise_reference( + batches, height, width, channels, filter_height, filter_width, (const void*)input.data(), + (const void*)weights.data(), (const void*)bias.data(), (void*)ref.data(), clamp_min, clamp_max, + padding); + + // ------------------------------------------------- + // 2. Pack weights for use in SME Kernel + // ------------------------------------------------- + // const size_t vec_length = kai_get_sme_vector_length_u32(); + const size_t packed_size = + kai_rhs_get_dst_size_dw_conv_pack_x32p1vl_x32_sme(filter_height, filter_width, channels) / + sizeof(float); + + // Run packing kernel. + std::vector weights_packed(packed_size); + kai_run_rhs_dw_conv_pack_x32p1vl_x32_sme( + weights.data(), weights_packed.data(), filter_height, filter_width, wei_shape[0], wei_shape[1], + channels); + +#ifdef KAI_DEBUG + const size_t vec_length = kai_get_sme_vector_length_u32(); + // Print packed weights - 1VL per row. 
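+                    // Illustrative note on the packed layout: the pack routine groups the
+                    // weights into channel blocks of one vector length (VL). Within each
+                    // block the kh * kw filter points are stored consecutively, each
+                    // occupying VL floats, so every row printed below corresponds to one
+                    // filter point of one channel block.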
+ print_tensor( + {1, (weights_packed.size() / vec_length), 1, vec_length}, + "\n Weights Packed : ", weights_packed.data()); +#endif + // ------------------------------------------------- + // 3. Kernel handles 4 rows at a time. + // We must adjust output pointer to the start of each 4 rows to be output and the input pointer + // passed to kernel. Kernel also expects 4 element array of strides and vec_lengths which are + // determined by if + // the specified output rows are valid or not. If not, 0 is passed in place of value. + // ------------------------------------------------- + constexpr size_t rows_handled = 4; // no of rows kernel handles each time. + for (size_t out_row = 0; out_row < out_shape.h; out_row += rows_handled) { + // printf("\nCalculating output rows : {%zu} - {%zu}\n", out_row, out_row + rows_handled); + + // Variables below used to calculate start of input pointer. + const int start_in_row = out_row - padding.top; + const size_t pad_top = (start_in_row < 0) ? (-start_in_row) : 0; + const size_t in_row = (start_in_row < 0) ? 0 : start_in_row; + + // Calculate row strides for pointer. + const size_t in_row_stride_elements = (width * channels); + const size_t out_row_stride_elements = (out_shape.w * out_shape.c); + + // Number of input rows that can be read, number of output rows to calculate. + const size_t valid_input_rows = (in_row < height) ? (height - in_row) : 0; + const size_t valid_out_rows = (out_shape.h - out_row); + + // Increment output/input pointers according to tile being calculated. + const auto inptr = input.data() + (in_row * in_row_stride_elements); + auto outptr = out.data() + (out_row * out_row_stride_elements); + + // NOTE: Kernel expects strides to be passed as bytes. + kai_run_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( + inptr, in_row_stride_elements * sizeof(float), channels * sizeof(float), pad_top, + padding.left, valid_input_rows, weights_packed.data(), bias.data(), outptr, + out_shape.c * sizeof(float), out_row_stride_elements * sizeof(float), valid_out_rows, + clamp_min, clamp_max, 0.0f); + } + +#ifdef KAI_DEBUG + // Print outputs + print_tensor(out_shape, "Reference : ", reinterpret_cast(ref.data())); + print_tensor(out_shape, "\n\n Actual : ", out.data()); + std::cout << "\n\nOut shape : " << out_shape << std::endl; +#endif // KAI_DEBUG + + /// Check for mismatches in the tests. 
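+                    // For instance, a reference value of 200.0f may deviate by at most
+                    // 0.1f (0.05%) before being flagged as a mismatch.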
+ size_t mismatches = 0; + for (size_t i = 0; i < out_shape.size(); i++) { + float ref_value = ref[i]; + // FP32 rel tolerance - allows deviations of up to 0.05% + const auto err = (std::abs(out[i] - ref_value) / std::abs(ref_value)); + if (err > 0.0005) { + std::cout << "Mismatches(Expected:Actual)" << ref_value << " : " << out[i] << std::endl; + mismatches++; + } + if (mismatches > 0) { + std::cout << "\nNumber of mismatches: " << mismatches << std::endl; + } + } + } + } + } + } + std::cout << "total tests run: " << total_test << std::endl; +} diff --git a/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c new file mode 100644 index 00000000..605f47e9 --- /dev/null +++ b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c @@ -0,0 +1,513 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include "kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h" + +#include +#include + +#include "kai/kai_common.h" + +// Number of rows iterated through each call. +static const size_t kai_mr = 4; + +// Filter/Kernel Height and width. +static const size_t kai_kh = 3; +static const size_t kai_kw = 3; + +size_t kai_get_dst_size_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( + const size_t out_height, const size_t out_width, const size_t num_channels) { + return out_height * out_width * num_channels * sizeof(float); +} + +size_t kai_get_m_step_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla() { + return kai_mr; +} + +size_t kai_get_filter_height_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla() { + return kai_kh; +} + +size_t kai_get_filter_width_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla() { + return kai_kw; +} + +size_t kai_get_dst_offset_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( + size_t out_row_idx, size_t stride_out_row) { + KAI_ASSUME(out_row_idx % kai_mr == 0); + return (out_row_idx * stride_out_row); +} + +void kai_run_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( + const void* inptr, size_t stride_in_row, size_t stride_in_col, unsigned int pad_top, unsigned int pad_left, + unsigned int valid_input_rows, const void* weights, const void* bias, void* outptr_start, size_t stride_out_col, + size_t stride_out_row, unsigned int valid_out_rows, float act_min, float act_max, const float pad_value) { + KAI_ASSUME(inptr != NULL); + KAI_ASSUME(weights != NULL); + KAI_ASSUME(outptr_start != NULL); + KAI_ASSUME(valid_out_rows != 0); + KAI_ASSUME(pad_value == 0.0f); + + typedef struct { + const void* inptr; + size_t ld_in_vl; + long unsigned int pad_top, pad_bottom, pad_left; + const void* weights; + const void* bias; + long unsigned int input_cols, output_cols; + void** outptrs; + const size_t* ld_out_cols; + const size_t* ld_out_vls; + long unsigned int current_channel, n_channels; + float clamp_min, clamp_max; + } Args; + + // Create padding row. + float pad_row[KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(float)] = {0}; + + // Calculate bottom padding offset. + long unsigned int pad_bottom = (6U < (pad_top + valid_input_rows)) ? 
(0U) : (6U - (pad_top + valid_input_rows)); + + // Leading dims calculated using input parameters. + size_t ld_in_vl = kai_get_sme_vector_length_u32(); + size_t ld_in_row = stride_in_row / sizeof(float); + size_t ld_in_col = stride_in_col / sizeof(float); + + // Calculate matrix dimensions using the strides provided. + const size_t num_channels = stride_out_col / sizeof(float); + const size_t output_cols = stride_out_row / (sizeof(float) * num_channels); + const size_t valid_input_cols = stride_in_row / (sizeof(float) * num_channels); + + // These arrays are initilised as if they were invalid/padded rows, then set if out row is valid + void* outptrs[4] = {pad_row, pad_row, pad_row, pad_row}; + size_t outlds[4] = {0}; + size_t outvllds[4] = {0}; + + for (unsigned int i = 0; i < 4; i++) { + if (i < valid_out_rows) { + outptrs[i] = (uint8_t*)outptr_start + (i * stride_out_row); + outlds[i] = num_channels; + outvllds[i] = ld_in_vl; + } + } + + Args args; + args.inptr = inptr; + args.ld_in_vl = ld_in_vl; + args.pad_top = pad_top; + args.pad_bottom = pad_bottom; + args.pad_left = pad_left; + args.weights = weights; + args.bias = bias; + args.input_cols = valid_input_cols; + args.output_cols = output_cols; + args.outptrs = outptrs; + args.ld_out_cols = outlds; + args.ld_out_vls = outvllds; + args.current_channel = 0; + args.n_channels = num_channels; + args.clamp_min = act_min; + args.clamp_max = act_max; + + __asm__ __volatile__( + "ldr x7, [%x[args], %[offsetof_Args_pad_bottom]]\n" + "mov x20, #0x6\n" + ".inst 0xd503477f // SMSTART ZA\n" + "ldr x17, [%x[args], %[offsetof_Args_pad_top]]\n" + "ptrue p2.b\n" + ".inst 0x25207812 // ptrue pn10.b\n" + "ldr x16, [%x[args], %[offsetof_Args_n_channels]]\n" + "ld1rw { z3.s }, p2/Z, [%x[args], %[offsetof_Args_clamp_min]]\n" + "sub x20, x20, x7\n" + "ldr x15, [%x[args], %[offsetof_Args_current_channel]]\n" + "ld1rw { z9.s }, p2/Z, [%x[args], %[offsetof_Args_clamp_max]]\n" + "whilelt p1.s, XZR, x16\n" + "whilelt p9.s, XZR, x20\n" + "whilelt p8.s, XZR, x17\n" + "eor p8.b, p2/Z, p8.b, p9.b\n" + "1:" // Channel loop + "ldr x20, [%x[args], %[offsetof_Args_bias]]\n" + "fmov z16.s, #0x0\n" + "cbz x20, 2f\n" + "ld1w { z16.s }, p1/Z, [x20, x15, LSL #2]\n" + "2:" // Load bias: Done + "ldr x14, [%x[args], %[offsetof_Args_input_cols]]\n" + "mov x23, #0x6\n" + "add x20, x17, x7\n" + "mov z17.d, z16.d\n" + "ldr x22, [%x[args], %[offsetof_Args_weights]]\n" + "lsl x21, %x[ld_in_row], #0x2\n" + "mov z18.d, z16.d\n" + "mov z19.d, z16.d\n" + "ldr x13, [%x[args], %[offsetof_Args_inptr]]\n" + "mov x8, #0x0\n" + "sub x23, x23, x20\n" + "sub x20, x14, #0x1\n" + "ldr x11, [%x[args], %[offsetof_Args_output_cols]]\n" + ".inst 0xa0404ace // ld1w { z14.s-z15.s }, pn10.b/Z, [x22]\n" + "orr x20, x20, %x[ld_in_col], LSL #18\n" + "ld1w { z11.s }, p2/Z, [x22, #2, MUL VL]\n" + "addvl x22, x22, #3\n" + "orr x20, x16, x20, LSL #20\n" + ".inst 0xa0404acc // ld1w { z12.s-z13.s }, pn10.b/Z, [x22]\n" + "lsl x20, x20, #0x2\n" + "madd x21, x21, x17, x13\n" + "ld1w { z0.s }, p2/Z, [x22, #2, MUL VL]\n" + "addvl x22, x22, #3\n" + ".inst 0xa0404ac4 // ld1w { z4.s-z5.s }, pn10.b/Z, [x22]\n" + "ld1w { z7.s }, p2/Z, [x22, #2, MUL VL]\n" + "3:" // Issue prefetches + "subs x23, x23, #0x1\n" + ".inst 0xf8b44abc // rprfm pldstrm, x20, [x21]\n" + "add x21, x21, %x[ld_in_col], LSL #2\n" + "bgt 3b\n" + "ldr x22, [%x[args], %[offsetof_Args_outptrs]]\n" + "lsl x21, %x[ld_in_row], #0x2\n" + ".inst 0xc0040e00 // mova za.d[x8, #0], { z16.d-z19.d }\n" + "mov x10, #0x2\n" + "ldr x20, [%x[args], 
%[offsetof_Args_ld_out_cols]]\n" + "msub x13, x17, x21, x13\n" + ".inst 0xc0040e01 // mova za.d[x8, #1], { z16.d-z19.d }\n" + "ldr x21, [%x[args], %[offsetof_Args_pad_left]]\n" + ".inst 0xc0040e02 // mova za.d[x8, #2], { z16.d-z19.d }\n" + "ldp x9, x28, [x22], #0x10\n" + "ldp x27, x26, [x20], #0x10\n" + "ldp x25, x24, [x22], #0x10\n" + "ldp x23, x22, [x20], #0x10\n" + "cbz x21, 5f\n" + "cmp x21, x10\n" + "csel x20, x21, x10, LT\n" + "sub x21, x21, x20\n" + "sub x10, x10, x20\n" + "cbz x21, 5f\n" + ".inst 0xc0060c1c // mova { z28.d-z31.d }, za.d[x8, #0]\n" + "sub x11, x11, x21\n" + ".inst 0xc1a9c87c // fclamp { z28.s-z31.s }, z3.s, z9.s\n" + "4:" // Left padding + "subs x21, x21, #0x1\n" + "st1w { z28.s }, p1, [x9]\n" + "add x9, x9, x27, LSL #2\n" + "st1w { z29.s }, p1, [x28]\n" + "add x28, x28, x26, LSL #2\n" + "st1w { z30.s }, p1, [x25]\n" + "add x25, x25, x23, LSL #2\n" + "st1w { z31.s }, p1, [x24]\n" + "add x24, x24, x22, LSL #2\n" + "bgt 4b\n" + "5:" // Left padding: End + "adds XZR, x17, x7\n" + "bne 10f\n" + "cbz x10, 8f\n" + "cmp x10, #0x1\n" + "sub x14, x14, x10\n" + "beq 7f\n" + "add x20, x13, %x[ld_in_row], LSL #2\n" + "ld1w { z22.s }, p1/Z, [x13]\n" + "add x13, x13, %x[ld_in_col], LSL #2\n" + "ld1w { z23.s }, p1/Z, [x20]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "ld1w { z24.s }, p1/Z, [x20]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "ld1w { z25.s }, p1/Z, [x20]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "ld1w { z26.s }, p1/Z, [x20]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "ld1w { z27.s }, p1/Z, [x20]\n" + ".inst 0xc13e1ac0 // fmla za.s[x8, 0], { z22.s-z25.s }, z14.s\n" + ".inst 0xc13c1ae0 // fmla za.s[x8, 0], { z23.s-z26.s }, z12.s\n" + ".inst 0xc1341b00 // fmla za.s[x8, 0], { z24.s-z27.s }, z4.s\n" + "7:" // Unpadded: 1 priming loads + "add x20, x13, %x[ld_in_row], LSL #2\n" + "ld1w { z24.s }, p1/Z, [x13]\n" + "add x13, x13, %x[ld_in_col], LSL #2\n" + "ld1w { z25.s }, p1/Z, [x20]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "ld1w { z26.s }, p1/Z, [x20]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "ld1w { z27.s }, p1/Z, [x20]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "ld1w { z28.s }, p1/Z, [x20]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "ld1w { z29.s }, p1/Z, [x20]\n" + ".inst 0xc13f1b00 // fmla za.s[x8, 0], { z24.s-z27.s }, z15.s\n" + ".inst 0xc13e1b01 // fmla za.s[x8, 1], { z24.s-z27.s }, z14.s\n" + ".inst 0xc13d1b20 // fmla za.s[x8, 0], { z25.s-z28.s }, z13.s\n" + ".inst 0xc13c1b21 // fmla za.s[x8, 1], { z25.s-z28.s }, z12.s\n" + ".inst 0xc1351b40 // fmla za.s[x8, 0], { z26.s-z29.s }, z5.s\n" + ".inst 0xc1341b41 // fmla za.s[x8, 1], { z26.s-z29.s }, z4.s\n" + "8:" // Unpadded: 0 priming loads + "cbz x14, 16f\n" + "add x20, x13, %x[ld_in_row], LSL #2\n" + "ld1w { z20.s }, p1/Z, [x13]\n" + "sub x14, x14, #0x1\n" + "ld1w { z21.s }, p1/Z, [x20]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "sub x11, x11, #0x1\n" + "ld1w { z22.s }, p1/Z, [x20]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "cmp x14, x11\n" + "ld1w { z23.s }, p1/Z, [x20]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "csel x21, x14, x11, LT\n" + "ld1w { z24.s }, p1/Z, [x20]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "add x13, x13, %x[ld_in_col], LSL #2\n" + "ld1w { z25.s }, p1/Z, [x20]\n" + "sub x11, x11, x21\n" + "cbz x21, 15f\n" + "9:" // Unpadded: Main loop + ".inst 0xc13b1a80 // fmla za.s[x8, 0], { z20.s-z23.s }, z11.s\n" + "add x20, x13, %x[ld_in_row], LSL #2\n" + "subs x21, x21, #0x1\n" + ".inst 0xc13f1a81 // fmla za.s[x8, 1], { z20.s-z23.s }, z15.s\n" + ".inst 0xc13e1a82 
// fmla za.s[x8, 2], { z20.s-z23.s }, z14.s\n" + "ld1w { z20.s }, p1/Z, [x13]\n" + "add x13, x13, %x[ld_in_col], LSL #2\n" + ".inst 0xc1301aa0 // fmla za.s[x8, 0], { z21.s-z24.s }, z0.s\n" + ".inst 0xc13d1aa1 // fmla za.s[x8, 1], { z21.s-z24.s }, z13.s\n" + ".inst 0xc13c1aa2 // fmla za.s[x8, 2], { z21.s-z24.s }, z12.s\n" + "ld1w { z21.s }, p1/Z, [x20]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + ".inst 0xc1371ac0 // fmla za.s[x8, 0], { z22.s-z25.s }, z7.s\n" + ".inst 0xc1351ac1 // fmla za.s[x8, 1], { z22.s-z25.s }, z5.s\n" + ".inst 0xc1341ac2 // fmla za.s[x8, 2], { z22.s-z25.s }, z4.s\n" + "ld1w { z22.s }, p1/Z, [x20]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "ld1w { z23.s }, p1/Z, [x20]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + ".inst 0xc0060c1c // mova { z28.d-z31.d }, za.d[x8, #0]\n" + "add x8, x8, #0x1\n" + "ld1w { z24.s }, p1/Z, [x20]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + ".inst 0xc0040e02 // mova za.d[x8, #2], { z16.d-z19.d }\n" + "ld1w { z25.s }, p1/Z, [x20]\n" + ".inst 0xc1a9c87c // fclamp { z28.s-z31.s }, z3.s, z9.s\n" + "st1w { z28.s }, p1, [x9]\n" + "add x9, x9, x27, LSL #2\n" + "st1w { z29.s }, p1, [x28]\n" + "add x28, x28, x26, LSL #2\n" + "st1w { z30.s }, p1, [x25]\n" + "add x25, x25, x23, LSL #2\n" + "st1w { z31.s }, p1, [x24]\n" + "add x24, x24, x22, LSL #2\n" + "bgt 9b\n" + "b 15f\n" + "10:" // Padded + "cbz x10, 13f\n" + "cmp x10, #0x1\n" + "sub x14, x14, x10\n" + "beq 12f\n" + "mov x12, #0x0\n" + "add x20, x13, %x[ld_in_row], LSL #2\n" + ".inst 0x25304500 // psel p0.s, p1.s/Z, p8.s[w12]\n" + "ld1w { z23.s }, p0/Z, [x13]\n" + ".inst 0x25704500 // psel p0.s, p1.s/Z, p8.s[w12, #1]\n" + "add x13, x13, %x[ld_in_col], LSL #2\n" + "ld1w { z24.s }, p0/Z, [x20]\n" + ".inst 0x25b04500 // psel p0.s, p1.s/Z, p8.s[w12, #2]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "ld1w { z25.s }, p0/Z, [x20]\n" + ".inst 0x25f04500 // psel p0.s, p1.s/Z, p8.s[w12, #3]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "mov x12, #0x4\n" + "ld1w { z26.s }, p0/Z, [x20]\n" + ".inst 0x25304500 // psel p0.s, p1.s/Z, p8.s[w12]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + ".inst 0xc13e1ae0 // fmla za.s[x8, 0], { z23.s-z26.s }, z14.s\n" + "ld1w { z27.s }, p0/Z, [x20]\n" + ".inst 0x25704500 // psel p0.s, p1.s/Z, p8.s[w12, #1]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "ld1w { z28.s }, p0/Z, [x20]\n" + ".inst 0xc13c1b00 // fmla za.s[x8, 0], { z24.s-z27.s }, z12.s\n" + ".inst 0xc1341b20 // fmla za.s[x8, 0], { z25.s-z28.s }, z4.s\n" + "12:" // Padded: 1 priming loads + "mov x12, #0x0\n" + "add x20, x13, %x[ld_in_row], LSL #2\n" + ".inst 0x25304500 // psel p0.s, p1.s/Z, p8.s[w12]\n" + "ld1w { z25.s }, p0/Z, [x13]\n" + ".inst 0x25704500 // psel p0.s, p1.s/Z, p8.s[w12, #1]\n" + "add x13, x13, %x[ld_in_col], LSL #2\n" + "ld1w { z26.s }, p0/Z, [x20]\n" + ".inst 0x25b04500 // psel p0.s, p1.s/Z, p8.s[w12, #2]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "ld1w { z27.s }, p0/Z, [x20]\n" + ".inst 0x25f04500 // psel p0.s, p1.s/Z, p8.s[w12, #3]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "mov x12, #0x4\n" + "ld1w { z28.s }, p0/Z, [x20]\n" + ".inst 0x25304500 // psel p0.s, p1.s/Z, p8.s[w12]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + ".inst 0xc13f1b20 // fmla za.s[x8, 0], { z25.s-z28.s }, z15.s\n" + "ld1w { z29.s }, p0/Z, [x20]\n" + ".inst 0x25704500 // psel p0.s, p1.s/Z, p8.s[w12, #1]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + ".inst 0xc13e1b21 // fmla za.s[x8, 1], { z25.s-z28.s }, z14.s\n" + "ld1w { z30.s }, p0/Z, [x20]\n" + ".inst 0xc13d1b40 // fmla za.s[x8, 0], { z26.s-z29.s }, 
z13.s\n" + ".inst 0xc13c1b41 // fmla za.s[x8, 1], { z26.s-z29.s }, z12.s\n" + ".inst 0xc1351b60 // fmla za.s[x8, 0], { z27.s-z30.s }, z5.s\n" + ".inst 0xc1341b61 // fmla za.s[x8, 1], { z27.s-z30.s }, z4.s\n" + "13:" // Padded: 0 priming loads + "cbz x14, 16f\n" + "mov x12, #0x0\n" + "add x20, x13, %x[ld_in_row], LSL #2\n" + ".inst 0x25304500 // psel p0.s, p1.s/Z, p8.s[w12]\n" + "sub x14, x14, #0x1\n" + "sub x11, x11, #0x1\n" + "cmp x14, x11\n" + "ld1w { z20.s }, p0/Z, [x13]\n" + ".inst 0x25704500 // psel p0.s, p1.s/Z, p8.s[w12, #1]\n" + "csel x21, x14, x11, LT\n" + "add x13, x13, %x[ld_in_col], LSL #2\n" + "sub x11, x11, x21\n" + "ld1w { z21.s }, p0/Z, [x20]\n" + ".inst 0x25b04500 // psel p0.s, p1.s/Z, p8.s[w12, #2]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "ld1w { z22.s }, p0/Z, [x20]\n" + ".inst 0x25f04500 // psel p0.s, p1.s/Z, p8.s[w12, #3]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "mov x12, #0x4\n" + "ld1w { z23.s }, p0/Z, [x20]\n" + ".inst 0x25304500 // psel p0.s, p1.s/Z, p8.s[w12]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "ld1w { z24.s }, p0/Z, [x20]\n" + ".inst 0x25704500 // psel p0.s, p1.s/Z, p8.s[w12, #1]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "ld1w { z25.s }, p0/Z, [x20]\n" + "cbz x21, 15f\n" + "14:" // Padded: Main loop + "mov x12, #0x0\n" + ".inst 0xc13b1a80 // fmla za.s[x8, 0], { z20.s-z23.s }, z11.s\n" + "add x20, x13, %x[ld_in_row], LSL #2\n" + ".inst 0x25304500 // psel p0.s, p1.s/Z, p8.s[w12]\n" + ".inst 0xc13f1a81 // fmla za.s[x8, 1], { z20.s-z23.s }, z15.s\n" + "subs x21, x21, #0x1\n" + ".inst 0xc13e1a82 // fmla za.s[x8, 2], { z20.s-z23.s }, z14.s\n" + "ld1w { z20.s }, p0/Z, [x13]\n" + ".inst 0x25704500 // psel p0.s, p1.s/Z, p8.s[w12, #1]\n" + ".inst 0xc1301aa0 // fmla za.s[x8, 0], { z21.s-z24.s }, z0.s\n" + "add x13, x13, %x[ld_in_col], LSL #2\n" + ".inst 0xc13d1aa1 // fmla za.s[x8, 1], { z21.s-z24.s }, z13.s\n" + ".inst 0xc13c1aa2 // fmla za.s[x8, 2], { z21.s-z24.s }, z12.s\n" + "ld1w { z21.s }, p0/Z, [x20]\n" + ".inst 0x25b04500 // psel p0.s, p1.s/Z, p8.s[w12, #2]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + ".inst 0xc1371ac0 // fmla za.s[x8, 0], { z22.s-z25.s }, z7.s\n" + ".inst 0xc1351ac1 // fmla za.s[x8, 1], { z22.s-z25.s }, z5.s\n" + ".inst 0xc1341ac2 // fmla za.s[x8, 2], { z22.s-z25.s }, z4.s\n" + "ld1w { z22.s }, p0/Z, [x20]\n" + ".inst 0x25f04500 // psel p0.s, p1.s/Z, p8.s[w12, #3]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + "mov x12, #0x4\n" + ".inst 0xc0060c1c // mova { z28.d-z31.d }, za.d[x8, #0]\n" + "add x8, x8, #0x1\n" + "ld1w { z23.s }, p0/Z, [x20]\n" + ".inst 0x25304500 // psel p0.s, p1.s/Z, p8.s[w12]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + ".inst 0xc0040e02 // mova za.d[x8, #2], { z16.d-z19.d }\n" + "ld1w { z24.s }, p0/Z, [x20]\n" + ".inst 0x25704500 // psel p0.s, p1.s/Z, p8.s[w12, #1]\n" + "add x20, x20, %x[ld_in_row], LSL #2\n" + ".inst 0xc1a9c87c // fclamp { z28.s-z31.s }, z3.s, z9.s\n" + "ld1w { z25.s }, p0/Z, [x20]\n" + "st1w { z28.s }, p1, [x9]\n" + "add x9, x9, x27, LSL #2\n" + "st1w { z29.s }, p1, [x28]\n" + "add x28, x28, x26, LSL #2\n" + "st1w { z30.s }, p1, [x25]\n" + "add x25, x25, x23, LSL #2\n" + "st1w { z31.s }, p1, [x24]\n" + "add x24, x24, x22, LSL #2\n" + "bgt 14b\n" + "15:" // Main loop tail + ".inst 0xc13b1a80 // fmla za.s[x8, 0], { z20.s-z23.s }, z11.s\n" + ".inst 0xc13f1a81 // fmla za.s[x8, 1], { z20.s-z23.s }, z15.s\n" + ".inst 0xc13e1a82 // fmla za.s[x8, 2], { z20.s-z23.s }, z14.s\n" + ".inst 0xc1301aa0 // fmla za.s[x8, 0], { z21.s-z24.s }, z0.s\n" + ".inst 0xc13d1aa1 // fmla za.s[x8, 1], { 
z21.s-z24.s }, z13.s\n" + ".inst 0xc13c1aa2 // fmla za.s[x8, 2], { z21.s-z24.s }, z12.s\n" + ".inst 0xc1371ac0 // fmla za.s[x8, 0], { z22.s-z25.s }, z7.s\n" + ".inst 0xc1351ac1 // fmla za.s[x8, 1], { z22.s-z25.s }, z5.s\n" + ".inst 0xc1341ac2 // fmla za.s[x8, 2], { z22.s-z25.s }, z4.s\n" + ".inst 0xc0060c14 // mova { z20.d-z23.d }, za.d[x8, #0]\n" + "add x8, x8, #0x1\n" + ".inst 0xc0040e02 // mova za.d[x8, #2], { z16.d-z19.d }\n" + ".inst 0xc1a9c874 // fclamp { z20.s-z23.s }, z3.s, z9.s\n" + "st1w { z20.s }, p1, [x9]\n" + "add x9, x9, x27, LSL #2\n" + "st1w { z21.s }, p1, [x28]\n" + "add x28, x28, x26, LSL #2\n" + "st1w { z22.s }, p1, [x25]\n" + "add x25, x25, x23, LSL #2\n" + "st1w { z23.s }, p1, [x24]\n" + "add x24, x24, x22, LSL #2\n" + "16:" // Main loop skip tail + "cbz x11, 18f\n" + "17:" // Right padding loop + ".inst 0xc0060c04 // mova { z4.d-z7.d }, za.d[x8, #0]\n" + "add x8, x8, #0x1\n" + "subs x11, x11, #0x1\n" + ".inst 0xc0040e02 // mova za.d[x8, #2], { z16.d-z19.d }\n" + ".inst 0xc1a9c864 // fclamp { z4.s-z7.s }, z3.s, z9.s\n" + "st1w { z4.s }, p1, [x9]\n" + "add x9, x9, x27, LSL #2\n" + "st1w { z5.s }, p1, [x28]\n" + "add x28, x28, x26, LSL #2\n" + "st1w { z6.s }, p1, [x25]\n" + "add x25, x25, x23, LSL #2\n" + "st1w { z7.s }, p1, [x24]\n" + "add x24, x24, x22, LSL #2\n" + "bgt 17b\n" + "18:" // End + "ldr x20, [%x[args], %[offsetof_Args_weights]]\n" + "incw x15\n" + "whilelt p1.s, x15, x16\n" + "incb x20, ALL, MUL #9\n" + "str x20, [%x[args], %[offsetof_Args_weights]]\n" + "ldr x21, [%x[args], %[offsetof_Args_ld_in_vl]]\n" + "ldr x20, [%x[args], %[offsetof_Args_inptr]]\n" + "add x20, x20, x21, LSL #2\n" + "str x20, [%x[args], %[offsetof_Args_inptr]]\n" + "ldr x25, [%x[args], %[offsetof_Args_outptrs]]\n" + "ldr x24, [%x[args], %[offsetof_Args_ld_out_vls]]\n" + "ldp x23, x22, [x25, #0x0]\n" + "ldp x21, x20, [x24, #0x0]\n" + "add x23, x23, x21, LSL #2\n" + "add x22, x22, x20, LSL #2\n" + "stp x23, x22, [x25, #0x0]\n" + "ldp x23, x22, [x25, #0x10]\n" + "ldp x21, x20, [x24, #0x10]\n" + "add x23, x23, x21, LSL #2\n" + "add x22, x22, x20, LSL #2\n" + "stp x23, x22, [x25, #0x10]\n" + "b.any 1b\n" + ".inst 0xd503467f // SMSTOP\n" + : + : [args] "r"(&args), [ld_in_col] "r"(ld_in_col), [ld_in_row] "r"(ld_in_row), + [offsetof_Args_bias] "I"(offsetof(Args, bias)), [offsetof_Args_clamp_max] "I"(offsetof(Args, clamp_max)), + [offsetof_Args_clamp_min] "I"(offsetof(Args, clamp_min)), + [offsetof_Args_current_channel] "I"(offsetof(Args, current_channel)), + [offsetof_Args_inptr] "I"(offsetof(Args, inptr)), [offsetof_Args_input_cols] "I"(offsetof(Args, input_cols)), + [offsetof_Args_ld_in_vl] "I"(offsetof(Args, ld_in_vl)), + [offsetof_Args_ld_out_cols] "I"(offsetof(Args, ld_out_cols)), + [offsetof_Args_ld_out_vls] "I"(offsetof(Args, ld_out_vls)), + [offsetof_Args_n_channels] "I"(offsetof(Args, n_channels)), + [offsetof_Args_outptrs] "I"(offsetof(Args, outptrs)), + [offsetof_Args_output_cols] "I"(offsetof(Args, output_cols)), + [offsetof_Args_pad_bottom] "I"(offsetof(Args, pad_bottom)), + [offsetof_Args_pad_left] "I"(offsetof(Args, pad_left)), [offsetof_Args_pad_top] "I"(offsetof(Args, pad_top)), + [offsetof_Args_weights] "I"(offsetof(Args, weights)) + : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", + "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", "x24", "x25", + "x26", "x27", "x28", "x7", "x8", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", + "z18", "z19", 
"z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", + "z4", "z5", "z6", "z7", "z8", "z9"); +} +#endif // Architectural features check. diff --git a/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h new file mode 100644 index 00000000..0fc65cb7 --- /dev/null +++ b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h @@ -0,0 +1,58 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#include + +/// @return output size in bytes. +size_t kai_get_dst_size_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( + const size_t out_height, const size_t out_width, const size_t num_channels); + +/// @return Maximum number of rows of output data produced by this kernel. +size_t kai_get_m_step_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(); + +/// @return Height of the filter used by this kernel. +size_t kai_get_filter_height_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(); + +/// @return Height of the filter used by this kernel. +size_t kai_get_filter_width_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(); + +/// @param[in] out_row_idx the row index of the output matrix +/// @param[in] stride_out_row Output row stride in bytes +/// @return offset to element in output/destination matrix. +size_t kai_get_dst_offset_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(size_t out_row_idx, size_t stride_out_row); + +/// Runs a depthwise convolution operation followed by a clamp operations +/// +/// @param[in] inptr Pointer to the start of valid input row to be processed. +/// @param[in] stride_in_row Row stride of input tensor in bytes. +/// Same as input_w * input_channel when row_dilation = 1 +/// @param[in] stride_in_col Column stride within the input tensor, in bytes. +/// @param[in] pad_top Number of zero pad rows that precede the first valid input row. +/// @param[in] pad_left Number of zero pad columns on the left edge. +/// @param[in] valid_input_rows Count of real input rows available from the start row (identifies bottom padding). +/// @param[in] weights Pointer to packed weights. +/// @param[in] bias Optional pointer to bias array (one float per channel). +/// @param[in] outptr_start Pointer to the first element of the top output row for this tile (four rows written) +/// @param[in] stride_out_col Output column stride in bytes. +/// @param[in] stride_out_row Output row stride in bytes. +/// @param[in] valid_out_rows Number of rows to output to (1-4). +/// @param[in] act_min Lower clamp bound applied to every output value. +/// @param[in] act_max Upper clamp bound applied to every output value. +/// @param[in] pad_value Fill value for padding. This kernel only supports 0. 
+void kai_run_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( + const void* inptr, size_t stride_in_row, size_t stride_in_col, unsigned int pad_top, unsigned int pad_left, + unsigned int valid_input_rows, const void* weights, const void* bias, void* outptr_start, size_t stride_out_col, + size_t stride_out_row, unsigned int valid_out_rows, float act_min, float act_max, const float pad_value); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c b/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c new file mode 100644 index 00000000..24dc36e3 --- /dev/null +++ b/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c @@ -0,0 +1,47 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "kai/kai_common.h" + +size_t kai_rhs_get_dst_size_dw_conv_pack_x32p1vl_x32_sme( + size_t filter_height, size_t filter_width, size_t num_channels) { + const size_t depth_elements = kai_roundup(num_channels, kai_get_sme_vector_length_u32()); + return depth_elements * filter_height * filter_width * sizeof(float); +} + +void kai_run_rhs_dw_conv_pack_x32p1vl_x32_sme( + void* rhs, void* rhs_packed, size_t filter_height, size_t filter_width, size_t height, size_t width, + size_t num_channels) { + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(rhs_packed != NULL); + KAI_UNUSED(height); + KAI_UNUSED(width); + + // Cast the pointers to byte sizes + const uint8_t* src = (uint8_t*)(rhs); + uint8_t* dst = (uint8_t*)(rhs_packed); + + const size_t vl = kai_get_sme_vector_length_u32(); + const size_t element_size = sizeof(float); + + // Create padding buffer to copy from. + const uint8_t pad_row[KAI_SME_VEC_LENGTH_MAX_BYTES] = {0}; + + for (size_t n = 0; n < num_channels; n += vl) { + // Copy each of the weights in turn + const size_t count = (vl < (num_channels - n)) ? vl : (num_channels - n); + for (size_t h = 0; h < filter_height; h++) { + for (size_t w = 0; w < filter_width; w++) { + uint8_t* src_ptr = pad_row; + if (h < filter_height && w < filter_width) { + src_ptr = src + ((h * filter_width + w) * num_channels + n) * element_size; + } + memcpy(dst, src_ptr, count * element_size); + dst += (vl * element_size); // move ptr. + } + } + } +} diff --git a/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.h b/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.h new file mode 100644 index 00000000..583b4872 --- /dev/null +++ b/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.h @@ -0,0 +1,43 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Get the size in bytes of the packed data buffer. +/// @param[in] filter_height filter height of filter being used in convolution. +/// @param[in] filter_width filter width of filter being used in convolution. +/// @param[in] num_channels Number of channels in input matrix. +/// @return The size in bytes of packed data buffer. +size_t kai_rhs_get_dst_size_dw_conv_pack_x32p1vl_x32_sme( + size_t filter_height, size_t filter_width, size_t num_channels); + +/// Runs the packing function for the depthwise convolution kernel. +/// +/// NOTE: filter_height/filter_width is seperate from height/width of weights to allow for padding when using weights +/// shapes different to kernel conv filter size. 
These should be the same in typical usecases. +/// +/// NOTE: The API below is experimental and may change in the future, particularly the parameters "height" and "width" +/// +/// @param[in] rhs Rhs matrix data buffer +/// @param[out] rhs_packed Packed data matrix buffer +/// @param[in] filter_height filter height of filter being used in convolution. +/// @param[in] filter_width filter width of filter being used in convolution. +/// @param[in] height Height dimension of rhs matrix. (Typically equivalent to filter_height) +/// @param[in] width Width dimension of rhs matrix. (Typically equivalent to filter_width) +/// @param[in] num_channels Number of channels in input matrix. +void kai_run_rhs_dw_conv_pack_x32p1vl_x32_sme( + void* rhs, void* rhs_packed, size_t filter_height, size_t filter_width, size_t height, size_t width, + size_t num_channels); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/test/reference/depthwise_conv.cpp b/test/reference/depthwise_conv.cpp new file mode 100644 index 00000000..14e1b26e --- /dev/null +++ b/test/reference/depthwise_conv.cpp @@ -0,0 +1,76 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/reference/depthwise_conv.hpp" + +#include "kai/kai_common.h" + +namespace kai::test { + +struct Padding2D { + size_t left; + size_t right; + size_t top; + size_t bottom; +}; + +template +Buffer depthwise_reference( + const size_t batches, const size_t in_height, const size_t in_width, const size_t channels, + const size_t filter_height, const size_t filter_width, const void* feature_map, const void* weights, + const void* bias, float clamp_min, float clamp_max) { + // Calculate output dims (Padding = Valid). + const size_t out_height = in_height + 1 - filter_height; + const size_t out_width = in_width + 1 - filter_width; + const size_t out_size = out_height * out_width * batches * channels; + + // We accumulate in FP32 and clamp and cast to return type later. + std::vector acc(out_size, 0.0f); + Buffer dst(out_size * size_in_bits / 8); + + for (size_t b = 0; b < batches; ++b) { + for (size_t out_h = 0; out_h < out_height; ++out_h) { + for (size_t out_w = 0; out_w < out_width; ++out_w) { + // Apply filter to feature map. + for (size_t ic = 0; ic < channels; ++ic) { + for (size_t kernel_h = 0; kernel_h < filter_height; ++kernel_h) { + if (in_height <= (out_h + kernel_h)) continue; + for (size_t kernel_w = 0; kernel_w < filter_width; ++kernel_w) { + if (in_width <= (out_w + kernel_w)) continue; + auto in_idx = + ((b * in_height + (out_h + kernel_h)) * in_width + (out_w + kernel_w)) * channels + ic; + auto weights_idx = ((kernel_h * filter_width) + kernel_w) * channels + ic; + auto out_idx = ((b * out_height + out_h) * out_width + out_w) * channels + ic; + + auto wei_value = read_array(weights, weights_idx); + auto in_value = read_array(feature_map, in_idx); + + // Perform actual accumulation and store in output vector + acc[out_idx] += in_value * wei_value; + } + } + } + + // Apply bias. + for (size_t ic = 0; ic < channels; ++ic) { + auto out_idx = ((b * out_height + out_h) * out_width + out_w) * channels; + acc[out_idx + ic] += read_array(bias, ic); + } + } + } + } + + // Apply clamping to accumulator, cast to FP16 and store in output vector at the same idx. + for (size_t i = 0; i < out_size; i++) { + acc[i] = (clamp_min > acc[i]) ? clamp_min : acc[i]; + acc[i] = (clamp_max < acc[i]) ? 
clamp_max : acc[i]; + write_array(dst.data(), i, acc[i]); + } + + return dst; +} + +} // namespace kai::test diff --git a/test/reference/depthwise_conv.hpp b/test/reference/depthwise_conv.hpp new file mode 100644 index 00000000..f3982e86 --- /dev/null +++ b/test/reference/depthwise_conv.hpp @@ -0,0 +1,39 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include + +#include "test/common/buffer.hpp" +#include "test/common/data_type.hpp" +#include "test/common/memory.hpp" + +namespace kai::test { + +/// Depthwise Convolution function +/// +/// @tparam T Data type. +/// +/// @param[in] batches Batch dimension of feature map. +/// @param[in] in_height height of feature map. +/// @param[in] in_width width of feature map. +/// @param[in] channels Number of channels in feature map. +/// @param[in] filter_height Height dimension in filter. +/// @param[in] filter_width Width of convolution filter. +/// @param[in] feature_map Ptr to start of feature map. +/// @param[in] weights Ptr to start of weights buffer/tensor. +/// @param[in] bias Ptr to start of bias buffer. +/// @param[in] clamp_min float value to clamp output to (lower bound). +/// @param[in] clamp_max float value to clamp output to (upper bound). +/// +/// @return The result data buffer. +template +Buffer depthwise_reference( + const size_t batches, const size_t in_height, const size_t in_width, const size_t channels, + const size_t filter_height, const size_t filter_width, const void* feature_map, const void* weights, + const void* bias, float clamp_min, float clamp_max); + +} // namespace kai::test -- GitLab From 6369ba733fe037c1ea772612b599bf64daaba5f2 Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Fri, 18 Jul 2025 15:21:48 +0100 Subject: [PATCH 2/6] Fix review comments Fix warnings in building example Added comments Signed-off-by: Felix Thomasmathibalan --- .../dw_conv_f32_f32_f32p_planar_sme2/dconv.cpp | 15 ++++++--------- .../pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c | 2 +- .../pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.h | 4 ++-- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/examples/dw_conv_f32_f32_f32p_planar_sme2/dconv.cpp b/examples/dw_conv_f32_f32_f32p_planar_sme2/dconv.cpp index 61d438f8..5b0325f1 100644 --- a/examples/dw_conv_f32_f32_f32p_planar_sme2/dconv.cpp +++ b/examples/dw_conv_f32_f32_f32p_planar_sme2/dconv.cpp @@ -185,8 +185,6 @@ void depthwise_reference( } // namespace int main() { - // FIXED : We should not care about batches for this example, those are just loops. - // All other variables below are kernel specific. const int batches = 1; enum class pad_mode { SAME, VALID }; @@ -225,7 +223,7 @@ int main() { fill_matrix(in_shape.size(), input, 0.01f); fill_matrix(wei_shape.size(), weights, 0.02f); - fill_matrix_uniform(bias_shape.size(), bias, 0.f); + fill_matrix_uniform(bias_shape.size(), bias, 1.f); // For testing using Python. #ifdef KAI_DEBUG @@ -275,15 +273,11 @@ int main() { "\n Weights Packed : ", weights_packed.data()); #endif // ------------------------------------------------- - // 3. Kernel handles 4 rows at a time. - // We must adjust output pointer to the start of each 4 rows to be output and the input pointer - // passed to kernel. Kernel also expects 4 element array of strides and vec_lengths which are - // determined by if - // the specified output rows are valid or not. If not, 0 is passed in place of value. + // 3. 
Kernel takes in 6 rows of input and generates + // rows of output across all channels at a time. // ------------------------------------------------- constexpr size_t rows_handled = 4; // no of rows kernel handles each time. for (size_t out_row = 0; out_row < out_shape.h; out_row += rows_handled) { - // printf("\nCalculating output rows : {%zu} - {%zu}\n", out_row, out_row + rows_handled); // Variables below used to calculate start of input pointer. const int start_in_row = out_row - padding.top; @@ -303,6 +297,9 @@ int main() { auto outptr = out.data() + (out_row * out_row_stride_elements); // NOTE: Kernel expects strides to be passed as bytes. + // f32_f32_f32p1vl -> f32 output, f32 LHS, packed F32 rhs as 1VL blocks. + // 3x3_s : 3x3 filter with stride 1 + // 4xc : 4 output channels across the plane(c) is produced. kai_run_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( inptr, in_row_stride_elements * sizeof(float), channels * sizeof(float), pad_top, padding.left, valid_input_rows, weights_packed.data(), bias.data(), outptr, diff --git a/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c b/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c index 24dc36e3..3777872a 100644 --- a/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c +++ b/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c @@ -35,7 +35,7 @@ void kai_run_rhs_dw_conv_pack_x32p1vl_x32_sme( const size_t count = (vl < (num_channels - n)) ? vl : (num_channels - n); for (size_t h = 0; h < filter_height; h++) { for (size_t w = 0; w < filter_width; w++) { - uint8_t* src_ptr = pad_row; + const uint8_t* src_ptr = pad_row; if (h < filter_height && w < filter_width) { src_ptr = src + ((h * filter_width + w) * num_channels + n) * element_size; } diff --git a/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.h b/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.h index 583b4872..fa3a04ea 100644 --- a/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.h +++ b/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.h @@ -31,8 +31,8 @@ size_t kai_rhs_get_dst_size_dw_conv_pack_x32p1vl_x32_sme( /// @param[out] rhs_packed Packed data matrix buffer /// @param[in] filter_height filter height of filter being used in convolution. /// @param[in] filter_width filter width of filter being used in convolution. -/// @param[in] height Height dimension of rhs matrix. (Typically equivalent to filter_height) -/// @param[in] width Width dimension of rhs matrix. (Typically equivalent to filter_width) +/// @param[in] height Height dimension of rhs matrix. Unused. (Typically equivalent to filter_height) +/// @param[in] width Width dimension of rhs matrix. Unused. (Typically equivalent to filter_width) /// @param[in] num_channels Number of channels in input matrix. void kai_run_rhs_dw_conv_pack_x32p1vl_x32_sme( void* rhs, void* rhs_packed, size_t filter_height, size_t filter_width, size_t height, size_t width, -- GitLab From 9d95d87323984ee221ac087adf5cf2edcf349bf5 Mon Sep 17 00:00:00 2001 From: Mohammed Suhail Munshi Date: Fri, 18 Jul 2025 16:25:08 +0100 Subject: [PATCH 3/6] Fix Build issues. 
Signed-off-by: Mohammed Suhail Munshi --- CMakeLists.txt | 2 +- .../CMakeLists.txt | 17 ++++----------- .../dconv.cpp | 1 - ...2_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c | 10 ++++----- ...2_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h | 10 ++++----- .../kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c | 21 ++++++++----------- 6 files changed, 24 insertions(+), 37 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 23d53f19..6a237ce9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -438,7 +438,7 @@ if(KLEIDIAI_BUILD_TESTS) target_compile_options(kleidiai_test_framework PUBLIC ${KLEIDIAI_WARNING_FLAGS} - PUBLIC $<$>:-march=armv8-a> + PUBLIC $<$>:-march=armv8-a${KLEIDIAI_INTERNAL_EXTRA_ARCH}> ) if(MSVC) diff --git a/examples/dw_conv_f32_f32_f32p_planar_sme2/CMakeLists.txt b/examples/dw_conv_f32_f32_f32p_planar_sme2/CMakeLists.txt index d271b86a..86dbb2d9 100644 --- a/examples/dw_conv_f32_f32_f32p_planar_sme2/CMakeLists.txt +++ b/examples/dw_conv_f32_f32_f32p_planar_sme2/CMakeLists.txt @@ -6,34 +6,25 @@ cmake_minimum_required(VERSION 3.16) -project(main) +project(dw_conv_f32_f32_f32p_planar_sme2) set(CMAKE_CXX_STANDARD 17) set(KAI_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../) set(KAI_BUILD ${KAI_PATH}/build) -set(CMAKE_EXPORT_COMPILE_COMMANDS ON) - -# Create a tiny interface target that only carries the desired ISA flags -add_library(sve_flags INTERFACE) -target_compile_options(sve_flags INTERFACE - -march=armv8.2-a+sve+sve2 - -fno-tree-vectorize) include_directories(${KAI_PATH}) # Files requires to build the executable add_executable( - main dconv.cpp + dw_conv_f32_f32_f32p_planar_sme2 dconv.cpp "${KAI_PATH}/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c" "${KAI_PATH}/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c" ) -target_compile_options(main +target_compile_options(dw_conv_f32_f32_f32p_planar_sme2 PRIVATE "-march=armv8.2-a+sve+sve2;-fno-tree-vectorize" ) -target_link_libraries(main PRIVATE sve_flags) - -target_compile_definitions(main +target_compile_definitions(dw_conv_f32_f32_f32p_planar_sme2 PRIVATE $<$:KAI_DEBUG> ) diff --git a/examples/dw_conv_f32_f32_f32p_planar_sme2/dconv.cpp b/examples/dw_conv_f32_f32_f32p_planar_sme2/dconv.cpp index 5b0325f1..2203b794 100644 --- a/examples/dw_conv_f32_f32_f32p_planar_sme2/dconv.cpp +++ b/examples/dw_conv_f32_f32_f32p_planar_sme2/dconv.cpp @@ -278,7 +278,6 @@ int main() { // ------------------------------------------------- constexpr size_t rows_handled = 4; // no of rows kernel handles each time. for (size_t out_row = 0; out_row < out_shape.h; out_row += rows_handled) { - // Variables below used to calculate start of input pointer. const int start_in_row = out_row - padding.top; const size_t pad_top = (start_in_row < 0) ? 
(-start_in_row) : 0; diff --git a/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c index 605f47e9..3c50e885 100644 --- a/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c +++ b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c @@ -27,15 +27,15 @@ size_t kai_get_dst_size_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( return out_height * out_width * num_channels * sizeof(float); } -size_t kai_get_m_step_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla() { +size_t kai_get_m_step_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(void) { return kai_mr; } -size_t kai_get_filter_height_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla() { +size_t kai_get_filter_height_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(void) { return kai_kh; } -size_t kai_get_filter_width_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla() { +size_t kai_get_filter_width_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(void) { return kai_kw; } @@ -48,12 +48,12 @@ size_t kai_get_dst_offset_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( void kai_run_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( const void* inptr, size_t stride_in_row, size_t stride_in_col, unsigned int pad_top, unsigned int pad_left, unsigned int valid_input_rows, const void* weights, const void* bias, void* outptr_start, size_t stride_out_col, - size_t stride_out_row, unsigned int valid_out_rows, float act_min, float act_max, const float pad_value) { + size_t stride_out_row, unsigned int valid_out_rows, float act_min, float act_max, float pad_value) { KAI_ASSUME(inptr != NULL); KAI_ASSUME(weights != NULL); KAI_ASSUME(outptr_start != NULL); KAI_ASSUME(valid_out_rows != 0); - KAI_ASSUME(pad_value == 0.0f); + KAI_ASSUME(pad_value == 0.0F); typedef struct { const void* inptr; diff --git a/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h index 0fc65cb7..e1b00fcd 100644 --- a/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h +++ b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h @@ -14,16 +14,16 @@ extern "C" { /// @return output size in bytes. size_t kai_get_dst_size_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( - const size_t out_height, const size_t out_width, const size_t num_channels); + size_t out_height, size_t out_width, size_t num_channels); /// @return Maximum number of rows of output data produced by this kernel. -size_t kai_get_m_step_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(); +size_t kai_get_m_step_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(void); /// @return Height of the filter used by this kernel. -size_t kai_get_filter_height_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(); +size_t kai_get_filter_height_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(void); /// @return Height of the filter used by this kernel. 
-size_t kai_get_filter_width_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(); +size_t kai_get_filter_width_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(void); /// @param[in] out_row_idx the row index of the output matrix /// @param[in] stride_out_row Output row stride in bytes @@ -51,7 +51,7 @@ size_t kai_get_dst_offset_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(siz void kai_run_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( const void* inptr, size_t stride_in_row, size_t stride_in_col, unsigned int pad_top, unsigned int pad_left, unsigned int valid_input_rows, const void* weights, const void* bias, void* outptr_start, size_t stride_out_col, - size_t stride_out_row, unsigned int valid_out_rows, float act_min, float act_max, const float pad_value); + size_t stride_out_row, unsigned int valid_out_rows, float act_min, float act_max, float pad_value); #ifdef __cplusplus } // extern "C" diff --git a/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c b/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c index 3777872a..b870a8fa 100644 --- a/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c +++ b/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c @@ -4,6 +4,11 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.h" + +#include +#include + #include "kai/kai_common.h" size_t kai_rhs_get_dst_size_dw_conv_pack_x32p1vl_x32_sme( @@ -27,21 +32,13 @@ void kai_run_rhs_dw_conv_pack_x32p1vl_x32_sme( const size_t vl = kai_get_sme_vector_length_u32(); const size_t element_size = sizeof(float); - // Create padding buffer to copy from. - const uint8_t pad_row[KAI_SME_VEC_LENGTH_MAX_BYTES] = {0}; - for (size_t n = 0; n < num_channels; n += vl) { // Copy each of the weights in turn const size_t count = (vl < (num_channels - n)) ? vl : (num_channels - n); - for (size_t h = 0; h < filter_height; h++) { - for (size_t w = 0; w < filter_width; w++) { - const uint8_t* src_ptr = pad_row; - if (h < filter_height && w < filter_width) { - src_ptr = src + ((h * filter_width + w) * num_channels + n) * element_size; - } - memcpy(dst, src_ptr, count * element_size); - dst += (vl * element_size); // move ptr. - } + for (size_t idx = 0; idx < filter_height * filter_width; idx++) { + const uint8_t* src_ptr = src + ((idx * num_channels + n) * element_size); + memcpy(dst, src_ptr, count * element_size); + dst += (vl * element_size); // move ptr. 
} } } -- GitLab From 7d22780e5561a33fa0522cdeb5b7805cf7fcb81b Mon Sep 17 00:00:00 2001 From: Mohammed Suhail Munshi Date: Wed, 23 Jul 2025 16:07:09 +0100 Subject: [PATCH 4/6] Add depthwise kernel tests Signed-off-by: Mohammed Suhail Munshi --- CMakeLists.txt | 2 + .../kai_dw_conv_f32_f32_f32p_interface.h | 43 +++ test/reference/depthwise_conv.cpp | 51 ++- test/reference/depthwise_conv.hpp | 29 +- test/tests/depthwise_planar_test.cpp | 305 ++++++++++++++++++ 5 files changed, 402 insertions(+), 28 deletions(-) create mode 100644 kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p_interface.h create mode 100644 test/tests/depthwise_planar_test.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 6a237ce9..0b60fd9d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -320,6 +320,7 @@ set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c + kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c ) @@ -469,6 +470,7 @@ if(KLEIDIAI_BUILD_TESTS) test/tests/buffer_test.cpp test/tests/float16_test.cpp test/tests/imatmul_test.cpp + test/tests/depthwise_planar_test.cpp test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp diff --git a/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p_interface.h b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p_interface.h new file mode 100644 index 00000000..ba860d26 --- /dev/null +++ b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p_interface.h @@ -0,0 +1,43 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// All micro-kernels variants of the same type share the same interfaces +// In this case, the micro-kernel type is: dw_conv_f32_f32_f32p_planar +// NOTE: +// - get_n_step is not provided as n-step is not relevant in planar kernels. +// - get_lhs_packed_offset is not provided as the lhs is not packed with planar kernels. +// - get_rhs_packed_offset is not provided as rhs offset is not relevant with planar kernels. 
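+// - Typical usage (see the planar SME2 example and tests): query get_m_step() for the number
+//   of output rows produced per call, then call run_dw_conv once per block of m_step output
+//   rows, advancing the input and output pointers by the corresponding row strides.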
+ +/// Micro-kernel helper functions ("get" methods) +typedef size_t (*kai_dw_conv_f32_f32_f32p_planar_get_m_step_func_t)(void); +typedef size_t (*kai_dw_conv_f32_f32_f32p_planar_get_dst_offset_func_t)(size_t out_row_idx, size_t stride_out_row); +typedef size_t (*kai_dw_conv_f32_f32_f32p_planar_get_dst_size_func_t)( + size_t out_height, size_t out_width, size_t num_channels); + +/// Micro-kernel core function ("run" method) +typedef void (*kai_dw_conv_f32_f32_f32p_planar_run_dw_conv_func_t)( + const void* inptr, size_t stride_in_row, size_t stride_in_col, unsigned int pad_top, unsigned int pad_left, + unsigned int valid_input_rows, const void* weights, const void* bias, void* outptr_start, size_t stride_out_col, + size_t stride_out_row, unsigned int valid_out_rows, float act_min, float act_max, float pad_value); + +/// Micro-kernel interface +struct kai_dw_conv_f32_f32_f32p_planar_ukernel { + kai_dw_conv_f32_f32_f32p_planar_get_m_step_func_t get_m_step; + kai_dw_conv_f32_f32_f32p_planar_get_dst_offset_func_t get_dst_offset; + kai_dw_conv_f32_f32_f32p_planar_get_dst_size_func_t get_dst_size; + kai_dw_conv_f32_f32_f32p_planar_run_dw_conv_func_t run_dw_conv; +}; + +#ifdef __cplusplus +} +#endif diff --git a/test/reference/depthwise_conv.cpp b/test/reference/depthwise_conv.cpp index 14e1b26e..9bfbad9b 100644 --- a/test/reference/depthwise_conv.cpp +++ b/test/reference/depthwise_conv.cpp @@ -6,25 +6,24 @@ #include "test/reference/depthwise_conv.hpp" +#include + #include "kai/kai_common.h" namespace kai::test { -struct Padding2D { - size_t left; - size_t right; - size_t top; - size_t bottom; +void PrintTo(const Padding2D& pad, std::ostream* os) { + *os << "__PAD_" << pad.left << "_" << pad.right << "_" << pad.bottom << "_" << pad.top << "_"; }; template Buffer depthwise_reference( const size_t batches, const size_t in_height, const size_t in_width, const size_t channels, const size_t filter_height, const size_t filter_width, const void* feature_map, const void* weights, - const void* bias, float clamp_min, float clamp_max) { + const void* bias, const Padding2D& pad) { // Calculate output dims (Padding = Valid). - const size_t out_height = in_height + 1 - filter_height; - const size_t out_width = in_width + 1 - filter_width; + const size_t out_height = (in_height + pad.top + pad.bottom + 1 - filter_height); + const size_t out_width = in_width + pad.left + pad.right + 1 - filter_width; const size_t out_size = out_height * out_width * batches * channels; // We accumulate in FP32 and clamp and cast to return type later. @@ -34,42 +33,40 @@ Buffer depthwise_reference( for (size_t b = 0; b < batches; ++b) { for (size_t out_h = 0; out_h < out_height; ++out_h) { for (size_t out_w = 0; out_w < out_width; ++out_w) { + const size_t out_base = ((b * out_height + out_h) * out_width + out_w) * channels; + // Apply filter to feature map. for (size_t ic = 0; ic < channels; ++ic) { + float sum = 0.0f; + for (size_t kernel_h = 0; kernel_h < filter_height; ++kernel_h) { - if (in_height <= (out_h + kernel_h)) continue; + // Determine if input height bounds. If not, then this is padding. + const int in_y = static_cast(out_h + kernel_h) - static_cast(pad.top); + if (in_y < 0 || in_height <= static_cast(in_y)) continue; + for (size_t kernel_w = 0; kernel_w < filter_width; ++kernel_w) { - if (in_width <= (out_w + kernel_w)) continue; - auto in_idx = - ((b * in_height + (out_h + kernel_h)) * in_width + (out_w + kernel_w)) * channels + ic; + // Determine if in input width bounds, if not this is padding. 
+ const int in_x = static_cast(out_w + kernel_w) - static_cast(pad.left); + if (in_x < 0 || in_width <= static_cast(in_x)) continue; + + auto in_idx = ((b * in_height + in_y) * in_width + in_x) * channels + ic; auto weights_idx = ((kernel_h * filter_width) + kernel_w) * channels + ic; - auto out_idx = ((b * out_height + out_h) * out_width + out_w) * channels + ic; auto wei_value = read_array(weights, weights_idx); auto in_value = read_array(feature_map, in_idx); // Perform actual accumulation and store in output vector - acc[out_idx] += in_value * wei_value; + sum += in_value * wei_value; } } - } - // Apply bias. - for (size_t ic = 0; ic < channels; ++ic) { - auto out_idx = ((b * out_height + out_h) * out_width + out_w) * channels; - acc[out_idx + ic] += read_array(bias, ic); + auto out_idx = out_base + ic; + sum = sum + (T)read_array(bias, ic); + write_array(dst.data(), out_idx, sum); } } } } - - // Apply clamping to accumulator, cast to FP16 and store in output vector at the same idx. - for (size_t i = 0; i < out_size; i++) { - acc[i] = (clamp_min > acc[i]) ? clamp_min : acc[i]; - acc[i] = (clamp_max < acc[i]) ? clamp_max : acc[i]; - write_array(dst.data(), i, acc[i]); - } - return dst; } diff --git a/test/reference/depthwise_conv.hpp b/test/reference/depthwise_conv.hpp index f3982e86..edbd9b16 100644 --- a/test/reference/depthwise_conv.hpp +++ b/test/reference/depthwise_conv.hpp @@ -13,6 +13,32 @@ namespace kai::test { +struct Padding2D { + size_t left; + size_t right; + size_t top; + size_t bottom; + + struct Hash { + size_t operator()(const Padding2D pad) const { + return // + (std::hash{}(pad.left) << 0) ^ // + (std::hash{}(pad.right) << 1) ^ // + (std::hash{}(pad.top) << 2) ^ // + (std::hash{}(pad.bottom) << 3); + } + }; + +private: + friend bool operator==(const Padding2D& lhs, const Padding2D& rhs) { + return // + lhs.left == rhs.left && lhs.right == rhs.right && lhs.top == rhs.top && lhs.bottom == rhs.bottom; + } + friend std::ostream& operator<<(std::ostream& os, const Padding2D& shape); +}; + +void PrintTo(const Padding2D& pad, std::ostream* os); + /// Depthwise Convolution function /// /// @tparam T Data type. @@ -28,12 +54,13 @@ namespace kai::test { /// @param[in] bias Ptr to start of bias buffer. /// @param[in] clamp_min float value to clamp output to (lower bound). /// @param[in] clamp_max float value to clamp output to (upper bound). +/// @param[in] pad Padding object. /// /// @return The result data buffer. 
template Buffer depthwise_reference( const size_t batches, const size_t in_height, const size_t in_width, const size_t channels, const size_t filter_height, const size_t filter_width, const void* feature_map, const void* weights, - const void* bias, float clamp_min, float clamp_max); + const void* bias, const Padding2D& pad); } // namespace kai::test diff --git a/test/tests/depthwise_planar_test.cpp b/test/tests/depthwise_planar_test.cpp new file mode 100644 index 00000000..7c71c28d --- /dev/null +++ b/test/tests/depthwise_planar_test.cpp @@ -0,0 +1,305 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include +#include + +#include "kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h" +#include "kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p_interface.h" +#include "kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.h" +#include "test/common/buffer.hpp" +#include "test/common/compare.hpp" +#include "test/common/cpu_info.hpp" +#include "test/common/matmul_test_common.hpp" +#include "test/common/matrix_portion.hpp" +#include "test/common/memory.hpp" +#include "test/common/round.hpp" +#include "test/common/sme.hpp" +#include "test/reference/clamp.hpp" +#include "test/reference/depthwise_conv.cpp" +#include "test/reference/depthwise_conv.hpp" +#include "test/reference/fill.hpp" + +namespace kai::test { + +namespace { + +/// Interface for depthwise kernel. +struct DepthwisePlanarKernel { + std::function get_dst_size; + std::function get_dst_offset; + std::function get_m_step; + std::function + conv; +}; + +// Rhs packing kernel. +struct RhsPackDepthwiseKernel { + std::function get_rhs_packed_size; + std::function + pack; +}; + +/// Description of a Depthwise kernel set +struct Depthwise { + std::string_view name; + std::function is_supported; + std::pair filter; + DataType data_type; + DataType acc_type; + RhsPackDepthwiseKernel rhs; + DepthwisePlanarKernel depthwise; +}; + +/// Convenience types for testing. +using DepthwiseArray = std::array; +using DepthwiseParamsParams = std::tuple; +using DepthwisePlanarTest = testing::TestWithParam; + +/// Use interface for depthwise kernel +const kai_dw_conv_f32_f32_f32p_planar_ukernel& get_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla() { + static kai_dw_conv_f32_f32_f32p_planar_ukernel ukernel; + ukernel.get_m_step = kai_get_m_step_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla; + ukernel.get_dst_offset = kai_get_dst_offset_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla; + ukernel.get_dst_size = kai_get_dst_size_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla; + ukernel.run_dw_conv = kai_run_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla; + return ukernel; +} + +const DepthwiseArray& get_depthwise_methods() { + // FP32 kernels with 3x3 filter. 
+ static DepthwiseArray depthwise_methods{}; + depthwise_methods[0].name = "kai_depthwise_planar_f32p4_3x3_s1_4_sme2_mla"; + depthwise_methods[0].rhs.get_rhs_packed_size = kai_rhs_get_dst_size_dw_conv_pack_x32p1vl_x32_sme; + depthwise_methods[0].rhs.pack = kai_run_rhs_dw_conv_pack_x32p1vl_x32_sme; + depthwise_methods[0].is_supported = cpu_has_sme2; + depthwise_methods[0].filter = {3, 3}; + + const kai_dw_conv_f32_f32_f32p_planar_ukernel& ukernel_f32 = + get_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(); + depthwise_methods[0].data_type = DataType::FP32; + depthwise_methods[0].acc_type = DataType::FP32; + depthwise_methods[0].depthwise.get_m_step = ukernel_f32.get_m_step; + depthwise_methods[0].depthwise.get_dst_size = ukernel_f32.get_dst_size; + depthwise_methods[0].depthwise.get_dst_offset = ukernel_f32.get_dst_offset; + depthwise_methods[0].depthwise.conv = ukernel_f32.run_dw_conv; + return depthwise_methods; +} + +/// Test reference identification. +struct TestDataId { + MatMulShape in_shape; + MatMulShape wei_shape; + MatMulShape out_shape; + Padding2D pad; + DataType dt; + DataType dt_acc; + float clamp_rate; + + struct Hash { + size_t operator()(const TestDataId& test_id) const { + return // + (MatMulShape::Hash{}(test_id.in_shape) << 0) ^ // + (Padding2D::Hash{}(test_id.pad) << 1) ^ // + (std::hash{}(test_id.clamp_rate) << 2); // + } + }; + +private: + friend bool operator==(const TestDataId& lhs, const TestDataId& rhs) { + return // + lhs.in_shape == rhs.in_shape && // + lhs.pad == rhs.pad && // + lhs.clamp_rate == rhs.clamp_rate; // + } +}; + +/// Test reference data +struct TestData { + Buffer lhs; ///< LHS input matrix + Buffer rhs; ///< RHS input matrix + Buffer bias; ///< Bias vector + Buffer out; ///< Reference depthwise result + Buffer padding; ///< Padding buffer + Range clamp_range; ///< Clamp range +}; + +/// Generate reference data, caches it. +struct ReferenceGenerator { + /// Retrieve reference data for the provided test identification + static const TestData& get_test_reference(const TestDataId test_id) { + static std::unordered_map m_data; + if (const auto itr = m_data.find(test_id); itr != end(m_data)) { + return itr->second; + } + + return m_data[test_id] = generate_reference(test_id); + } + +private: + /// Return incremented seed value + static size_t get_seed() { + static size_t seed = 0; + return seed++; + } + + /// Generate reference data. 
+ static TestData generate_reference(const TestDataId& test_id) { + const auto& [in_shape, wei_shape, out_shape, pad, dt, acc_dt, clamp_rate] = test_id; + + // Generate random input data + Buffer lhs = fill_matrix_random(in_shape.m, in_shape.n * in_shape.k, DataFormat(dt), get_seed()); + Buffer rhs = fill_matrix_random(wei_shape.m, wei_shape.n * wei_shape.k, DataFormat(dt), get_seed()); + Buffer bias = fill_matrix_random(1, out_shape.k, DataFormat(dt), get_seed()); + + // Call reference function + Buffer out = depthwise_reference( + 1, in_shape.m, in_shape.n, in_shape.k, wei_shape.m, wei_shape.n, lhs.data(), rhs.data(), bias.data(), pad); + + const auto [min, max] = + find_clamp_range(dt, out.data(), out_shape.m * out_shape.n * out_shape.k, 1.0F - clamp_rate); + Buffer out_clamped = clamp(dt, out.data(), out_shape.m * out_shape.n * out_shape.k, min, max); + + // Populate reference data + TestData test_reference; + test_reference.lhs = std::move(lhs); + test_reference.rhs = std::move(rhs); + test_reference.bias = std::move(bias); + test_reference.out = std::move(out_clamped); + test_reference.clamp_range = {min, max}; + return test_reference; + }; +}; + +/// Perform RHS packing for depthwise +Buffer pack_rhs(const RhsPackDepthwiseKernel& kernel, const MatMulShape& shape, const TestData& reference) { + // Calculate size, and allocate buffer + const size_t dst_size = kernel.get_rhs_packed_size(shape.m, shape.n, shape.k); + Buffer dst(dst_size); + + // RHS Pack API is subject to change. + kernel.pack(reference.rhs.data(), dst.data(), shape.m, shape.n, shape.m, shape.n, shape.k); + return dst; +} + +Buffer dw_conv( + const DepthwisePlanarKernel& kernel, const Rect& portion, const MatMulShape& in_shape, const MatMulShape& out_shape, + const Padding2D pad, const TestData& reference, const Buffer& rhs_packed, Range clamp_range, DataType type) { + KAI_UNUSED(type); + + const size_t dst_size = kernel.get_dst_size(out_shape.m, out_shape.n, out_shape.k); + Buffer dst(dst_size); + + const size_t dt_size_bytes = data_type_size_in_bits(type) / 8; + const size_t stride_in_row = in_shape.n * in_shape.k * dt_size_bytes; + const size_t stride_out_row = out_shape.n * out_shape.k * dt_size_bytes; + const size_t stride_col = dt_size_bytes; + + // Loop the following. M-Step rows are handled at a time. + for (size_t out_row = portion.start_row(); out_row < portion.end_row(); out_row += kernel.get_m_step()) { + const int start_in_row = out_row - pad.top; + const size_t pad_top = (start_in_row < 0) ? (-start_in_row) : 0; + const size_t in_row = (start_in_row < 0) ? 0 : start_in_row; + + const size_t valid_input_rows = (in_row < in_shape.m) ? (in_shape.m - in_row) : 0; + const size_t valid_out_rows = (out_shape.m - out_row); + + kernel.conv( + reference.lhs.data() + (in_row * stride_in_row), stride_in_row, stride_col, pad_top, pad.left, + valid_input_rows, rhs_packed.data(), reference.bias.data(), dst.data() + (out_row * stride_out_row), + stride_col, stride_out_row, valid_out_rows, clamp_range.min, clamp_range.max, 0.f); + } + + return dst; +} +} // namespace + +/// End-to-end test for depthwise kernels +TEST_P(DepthwisePlanarTest, Output) { + const auto& [method, in_shape, padding, out_portion, clamp_rate] = GetParam(); + if (not method.is_supported()) { + GTEST_SKIP() << "Unsupported CPU feature"; + } + + // Calculate Shapes. 
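+        // For the stride-1 kernels under test: out_dim = in_dim + pad_before + pad_after - filter_dim + 1.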
+ int out_height = (in_shape.m + padding.top + padding.bottom + 1 - method.filter.first); + int out_width = (in_shape.n + padding.left + padding.right + 1 - method.filter.second); + if (out_height <= 0 || out_width <= 0) // Check shapes valid + { + GTEST_SKIP() << "Invalid DCONV Shape"; + } + + const size_t dt_size_bytes = data_type_size_in_bits(method.data_type) / 8; + MatMulShape wei_shape = {method.filter.first, method.filter.second, in_shape.k}; + MatMulShape out_shape = {static_cast(out_height), static_cast(out_width), (in_shape.k)}; + + // 1. Calculate reference. + const TestData& test_data = ReferenceGenerator::get_test_reference( + {in_shape, wei_shape, out_shape, padding, method.data_type, method.acc_type, clamp_rate}); + + // 2. Run DW Conv kernels and compare. + Buffer rhs_packed = pack_rhs(method.rhs, wei_shape, test_data); + const Rect portion = out_portion.compute_portion( + out_shape.m, out_shape.n * out_shape.k, method.depthwise.get_m_step(), (rhs_packed.size() / dt_size_bytes)); + + Buffer out = dw_conv( + method.depthwise, portion, in_shape, out_shape, padding, test_data, rhs_packed, test_data.clamp_range, + method.data_type); + + DefaultMismatchHandler handler(0, 0.0001, 0, 0.001); + const auto success = compare( + out.data(), test_data.out.data(), DataType::FP32, out_shape.m, out_shape.n * out_shape.k, portion, handler); + ASSERT_TRUE(success); +} + +/// Name generator for test case +[[maybe_unused]] static void PrintTo(const DepthwiseParamsParams& param, std::ostream* os) { + const auto& [method, shape, padding, portion, clamp_rate] = param; + *os << method.name << "__"; + PrintTo(shape, os); + PrintTo(padding, os); + *os << "__clamp_rate_" << static_cast(clamp_rate * 100) << "__"; + PrintTo(portion, os); +} + +/// Test parameter listing +INSTANTIATE_TEST_SUITE_P( + Depthwise, DepthwisePlanarTest, + testing::Combine( + testing::ValuesIn(get_depthwise_methods()), // + testing::ValuesIn({ + // clang-format off + // IN_HEIGHT, IN_WIDTH, IN_CHANNELS + MatMulShape{ 4, 4, 1}, // + MatMulShape{ 8, 4, 16}, // + MatMulShape{ 96, 33, 37}, // + MatMulShape{ 99, 22, 51}, // + MatMulShape{ 127, 127, 127}, // + // clang-format on + }), + testing::ValuesIn( + {Padding2D{0, 0, 0, 0}, Padding2D{0, 1, 0, 1}, Padding2D{1, 1, 1, 1}, Padding2D{5, 11, 7, 3}}), // + testing::ValuesIn({ + // clang-format off + // (Start row , start col , height , width) + MatrixPortion( 0 , 0 , 1 , 1 ), // Full matrix. 
+ // clang-format on + }), + testing::ValuesIn(std::initializer_list{1.f})), // + testing::PrintToStringParamName()); + +} // namespace kai::test -- GitLab From 26982c7fbde99002dabedeb8cc9166090b1ee970 Mon Sep 17 00:00:00 2001 From: Mohammed Suhail Munshi Date: Fri, 25 Jul 2025 13:21:58 +0100 Subject: [PATCH 5/6] Remove inline assembly to enable MSVC support and remove clang warnings Signed-off-by: Mohammed Suhail Munshi --- CMakeLists.txt | 1 + .../CMakeLists.txt | 2 + ...2_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c | 452 ++---------------- ...2_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h | 14 +- ...2_f32p1vl_3x3_s1_4xc_planar_sme2_mla_asm.S | 438 +++++++++++++++++ 5 files changed, 486 insertions(+), 421 deletions(-) create mode 100644 kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla_asm.S diff --git a/CMakeLists.txt b/CMakeLists.txt index 0b60fd9d..d4adf59c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -310,6 +310,7 @@ set(KLEIDIAI_FILES_SME2_ASM kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_asm.S kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_asm.S + kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla_asm.S ) set(KLEIDIAI_FILES_SME2 diff --git a/examples/dw_conv_f32_f32_f32p_planar_sme2/CMakeLists.txt b/examples/dw_conv_f32_f32_f32p_planar_sme2/CMakeLists.txt index 86dbb2d9..b58eff3f 100644 --- a/examples/dw_conv_f32_f32_f32p_planar_sme2/CMakeLists.txt +++ b/examples/dw_conv_f32_f32_f32p_planar_sme2/CMakeLists.txt @@ -5,6 +5,7 @@ # cmake_minimum_required(VERSION 3.16) +enable_language(ASM) project(dw_conv_f32_f32_f32p_planar_sme2) @@ -17,6 +18,7 @@ include_directories(${KAI_PATH}) # Files requires to build the executable add_executable( dw_conv_f32_f32_f32p_planar_sme2 dconv.cpp + "${KAI_PATH}/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla_asm.S" "${KAI_PATH}/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c" "${KAI_PATH}/kai/ukernels/dw_conv/pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme.c" ) diff --git a/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c index 3c50e885..9abbc1df 100644 --- a/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c +++ b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.c @@ -17,15 +17,32 @@ // Number of rows iterated through each call. static const size_t kai_mr = 4; - -// Filter/Kernel Height and width. 
+// Filter/Kernel height and width static const size_t kai_kh = 3; static const size_t kai_kw = 3; - -size_t kai_get_dst_size_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( - const size_t out_height, const size_t out_width, const size_t num_channels) { - return out_height * out_width * num_channels * sizeof(float); -} +static const size_t kai_kr = 1; + +typedef struct { + const void* inptr; + size_t pad_top; + size_t pad_bottom; + const void* weights; + size_t input_cols; + size_t output_cols; + void** outptrs; + const void* ld_out_cols; + size_t ld_in_vl; + const void* ld_out_vls; + float clamp_min; + float clamp_max; + size_t pad_left; + const void* bias; + size_t current_channel; + size_t n_channels; +} KernelArgs; + +void kai_kernel_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( + const KernelArgs* args, size_t ld_in_row, size_t ld_in_col); size_t kai_get_m_step_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(void) { return kai_mr; @@ -39,6 +56,15 @@ size_t kai_get_filter_width_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(v return kai_kw; } +size_t kai_get_kr_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(void) { + return kai_kr; +} + +size_t kai_get_dst_size_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( + const size_t out_height, const size_t out_width, const size_t num_channels) { + return out_height * out_width * num_channels * sizeof(float); +} + size_t kai_get_dst_offset_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( size_t out_row_idx, size_t stride_out_row) { KAI_ASSUME(out_row_idx % kai_mr == 0); @@ -48,27 +74,13 @@ size_t kai_get_dst_offset_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( void kai_run_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( const void* inptr, size_t stride_in_row, size_t stride_in_col, unsigned int pad_top, unsigned int pad_left, unsigned int valid_input_rows, const void* weights, const void* bias, void* outptr_start, size_t stride_out_col, - size_t stride_out_row, unsigned int valid_out_rows, float act_min, float act_max, float pad_value) { + size_t stride_out_row, unsigned int valid_out_rows, float act_min, float act_max, const float pad_value) { KAI_ASSUME(inptr != NULL); KAI_ASSUME(weights != NULL); KAI_ASSUME(outptr_start != NULL); KAI_ASSUME(valid_out_rows != 0); KAI_ASSUME(pad_value == 0.0F); - typedef struct { - const void* inptr; - size_t ld_in_vl; - long unsigned int pad_top, pad_bottom, pad_left; - const void* weights; - const void* bias; - long unsigned int input_cols, output_cols; - void** outptrs; - const size_t* ld_out_cols; - const size_t* ld_out_vls; - long unsigned int current_channel, n_channels; - float clamp_min, clamp_max; - } Args; - // Create padding row. 
float pad_row[KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(float)] = {0}; @@ -98,7 +110,7 @@ void kai_run_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( } } - Args args; + KernelArgs args; args.inptr = inptr; args.ld_in_vl = ld_in_vl; args.pad_top = pad_top; @@ -116,398 +128,6 @@ void kai_run_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( args.clamp_min = act_min; args.clamp_max = act_max; - __asm__ __volatile__( - "ldr x7, [%x[args], %[offsetof_Args_pad_bottom]]\n" - "mov x20, #0x6\n" - ".inst 0xd503477f // SMSTART ZA\n" - "ldr x17, [%x[args], %[offsetof_Args_pad_top]]\n" - "ptrue p2.b\n" - ".inst 0x25207812 // ptrue pn10.b\n" - "ldr x16, [%x[args], %[offsetof_Args_n_channels]]\n" - "ld1rw { z3.s }, p2/Z, [%x[args], %[offsetof_Args_clamp_min]]\n" - "sub x20, x20, x7\n" - "ldr x15, [%x[args], %[offsetof_Args_current_channel]]\n" - "ld1rw { z9.s }, p2/Z, [%x[args], %[offsetof_Args_clamp_max]]\n" - "whilelt p1.s, XZR, x16\n" - "whilelt p9.s, XZR, x20\n" - "whilelt p8.s, XZR, x17\n" - "eor p8.b, p2/Z, p8.b, p9.b\n" - "1:" // Channel loop - "ldr x20, [%x[args], %[offsetof_Args_bias]]\n" - "fmov z16.s, #0x0\n" - "cbz x20, 2f\n" - "ld1w { z16.s }, p1/Z, [x20, x15, LSL #2]\n" - "2:" // Load bias: Done - "ldr x14, [%x[args], %[offsetof_Args_input_cols]]\n" - "mov x23, #0x6\n" - "add x20, x17, x7\n" - "mov z17.d, z16.d\n" - "ldr x22, [%x[args], %[offsetof_Args_weights]]\n" - "lsl x21, %x[ld_in_row], #0x2\n" - "mov z18.d, z16.d\n" - "mov z19.d, z16.d\n" - "ldr x13, [%x[args], %[offsetof_Args_inptr]]\n" - "mov x8, #0x0\n" - "sub x23, x23, x20\n" - "sub x20, x14, #0x1\n" - "ldr x11, [%x[args], %[offsetof_Args_output_cols]]\n" - ".inst 0xa0404ace // ld1w { z14.s-z15.s }, pn10.b/Z, [x22]\n" - "orr x20, x20, %x[ld_in_col], LSL #18\n" - "ld1w { z11.s }, p2/Z, [x22, #2, MUL VL]\n" - "addvl x22, x22, #3\n" - "orr x20, x16, x20, LSL #20\n" - ".inst 0xa0404acc // ld1w { z12.s-z13.s }, pn10.b/Z, [x22]\n" - "lsl x20, x20, #0x2\n" - "madd x21, x21, x17, x13\n" - "ld1w { z0.s }, p2/Z, [x22, #2, MUL VL]\n" - "addvl x22, x22, #3\n" - ".inst 0xa0404ac4 // ld1w { z4.s-z5.s }, pn10.b/Z, [x22]\n" - "ld1w { z7.s }, p2/Z, [x22, #2, MUL VL]\n" - "3:" // Issue prefetches - "subs x23, x23, #0x1\n" - ".inst 0xf8b44abc // rprfm pldstrm, x20, [x21]\n" - "add x21, x21, %x[ld_in_col], LSL #2\n" - "bgt 3b\n" - "ldr x22, [%x[args], %[offsetof_Args_outptrs]]\n" - "lsl x21, %x[ld_in_row], #0x2\n" - ".inst 0xc0040e00 // mova za.d[x8, #0], { z16.d-z19.d }\n" - "mov x10, #0x2\n" - "ldr x20, [%x[args], %[offsetof_Args_ld_out_cols]]\n" - "msub x13, x17, x21, x13\n" - ".inst 0xc0040e01 // mova za.d[x8, #1], { z16.d-z19.d }\n" - "ldr x21, [%x[args], %[offsetof_Args_pad_left]]\n" - ".inst 0xc0040e02 // mova za.d[x8, #2], { z16.d-z19.d }\n" - "ldp x9, x28, [x22], #0x10\n" - "ldp x27, x26, [x20], #0x10\n" - "ldp x25, x24, [x22], #0x10\n" - "ldp x23, x22, [x20], #0x10\n" - "cbz x21, 5f\n" - "cmp x21, x10\n" - "csel x20, x21, x10, LT\n" - "sub x21, x21, x20\n" - "sub x10, x10, x20\n" - "cbz x21, 5f\n" - ".inst 0xc0060c1c // mova { z28.d-z31.d }, za.d[x8, #0]\n" - "sub x11, x11, x21\n" - ".inst 0xc1a9c87c // fclamp { z28.s-z31.s }, z3.s, z9.s\n" - "4:" // Left padding - "subs x21, x21, #0x1\n" - "st1w { z28.s }, p1, [x9]\n" - "add x9, x9, x27, LSL #2\n" - "st1w { z29.s }, p1, [x28]\n" - "add x28, x28, x26, LSL #2\n" - "st1w { z30.s }, p1, [x25]\n" - "add x25, x25, x23, LSL #2\n" - "st1w { z31.s }, p1, [x24]\n" - "add x24, x24, x22, LSL #2\n" - "bgt 4b\n" - "5:" // Left padding: End - "adds XZR, x17, x7\n" - "bne 10f\n" - "cbz x10, 8f\n" - 
"cmp x10, #0x1\n" - "sub x14, x14, x10\n" - "beq 7f\n" - "add x20, x13, %x[ld_in_row], LSL #2\n" - "ld1w { z22.s }, p1/Z, [x13]\n" - "add x13, x13, %x[ld_in_col], LSL #2\n" - "ld1w { z23.s }, p1/Z, [x20]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "ld1w { z24.s }, p1/Z, [x20]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "ld1w { z25.s }, p1/Z, [x20]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "ld1w { z26.s }, p1/Z, [x20]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "ld1w { z27.s }, p1/Z, [x20]\n" - ".inst 0xc13e1ac0 // fmla za.s[x8, 0], { z22.s-z25.s }, z14.s\n" - ".inst 0xc13c1ae0 // fmla za.s[x8, 0], { z23.s-z26.s }, z12.s\n" - ".inst 0xc1341b00 // fmla za.s[x8, 0], { z24.s-z27.s }, z4.s\n" - "7:" // Unpadded: 1 priming loads - "add x20, x13, %x[ld_in_row], LSL #2\n" - "ld1w { z24.s }, p1/Z, [x13]\n" - "add x13, x13, %x[ld_in_col], LSL #2\n" - "ld1w { z25.s }, p1/Z, [x20]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "ld1w { z26.s }, p1/Z, [x20]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "ld1w { z27.s }, p1/Z, [x20]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "ld1w { z28.s }, p1/Z, [x20]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "ld1w { z29.s }, p1/Z, [x20]\n" - ".inst 0xc13f1b00 // fmla za.s[x8, 0], { z24.s-z27.s }, z15.s\n" - ".inst 0xc13e1b01 // fmla za.s[x8, 1], { z24.s-z27.s }, z14.s\n" - ".inst 0xc13d1b20 // fmla za.s[x8, 0], { z25.s-z28.s }, z13.s\n" - ".inst 0xc13c1b21 // fmla za.s[x8, 1], { z25.s-z28.s }, z12.s\n" - ".inst 0xc1351b40 // fmla za.s[x8, 0], { z26.s-z29.s }, z5.s\n" - ".inst 0xc1341b41 // fmla za.s[x8, 1], { z26.s-z29.s }, z4.s\n" - "8:" // Unpadded: 0 priming loads - "cbz x14, 16f\n" - "add x20, x13, %x[ld_in_row], LSL #2\n" - "ld1w { z20.s }, p1/Z, [x13]\n" - "sub x14, x14, #0x1\n" - "ld1w { z21.s }, p1/Z, [x20]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "sub x11, x11, #0x1\n" - "ld1w { z22.s }, p1/Z, [x20]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "cmp x14, x11\n" - "ld1w { z23.s }, p1/Z, [x20]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "csel x21, x14, x11, LT\n" - "ld1w { z24.s }, p1/Z, [x20]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "add x13, x13, %x[ld_in_col], LSL #2\n" - "ld1w { z25.s }, p1/Z, [x20]\n" - "sub x11, x11, x21\n" - "cbz x21, 15f\n" - "9:" // Unpadded: Main loop - ".inst 0xc13b1a80 // fmla za.s[x8, 0], { z20.s-z23.s }, z11.s\n" - "add x20, x13, %x[ld_in_row], LSL #2\n" - "subs x21, x21, #0x1\n" - ".inst 0xc13f1a81 // fmla za.s[x8, 1], { z20.s-z23.s }, z15.s\n" - ".inst 0xc13e1a82 // fmla za.s[x8, 2], { z20.s-z23.s }, z14.s\n" - "ld1w { z20.s }, p1/Z, [x13]\n" - "add x13, x13, %x[ld_in_col], LSL #2\n" - ".inst 0xc1301aa0 // fmla za.s[x8, 0], { z21.s-z24.s }, z0.s\n" - ".inst 0xc13d1aa1 // fmla za.s[x8, 1], { z21.s-z24.s }, z13.s\n" - ".inst 0xc13c1aa2 // fmla za.s[x8, 2], { z21.s-z24.s }, z12.s\n" - "ld1w { z21.s }, p1/Z, [x20]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - ".inst 0xc1371ac0 // fmla za.s[x8, 0], { z22.s-z25.s }, z7.s\n" - ".inst 0xc1351ac1 // fmla za.s[x8, 1], { z22.s-z25.s }, z5.s\n" - ".inst 0xc1341ac2 // fmla za.s[x8, 2], { z22.s-z25.s }, z4.s\n" - "ld1w { z22.s }, p1/Z, [x20]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "ld1w { z23.s }, p1/Z, [x20]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - ".inst 0xc0060c1c // mova { z28.d-z31.d }, za.d[x8, #0]\n" - "add x8, x8, #0x1\n" - "ld1w { z24.s }, p1/Z, [x20]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - ".inst 0xc0040e02 // mova za.d[x8, #2], { z16.d-z19.d }\n" - "ld1w { z25.s }, p1/Z, [x20]\n" - ".inst 0xc1a9c87c // fclamp { z28.s-z31.s 
}, z3.s, z9.s\n" - "st1w { z28.s }, p1, [x9]\n" - "add x9, x9, x27, LSL #2\n" - "st1w { z29.s }, p1, [x28]\n" - "add x28, x28, x26, LSL #2\n" - "st1w { z30.s }, p1, [x25]\n" - "add x25, x25, x23, LSL #2\n" - "st1w { z31.s }, p1, [x24]\n" - "add x24, x24, x22, LSL #2\n" - "bgt 9b\n" - "b 15f\n" - "10:" // Padded - "cbz x10, 13f\n" - "cmp x10, #0x1\n" - "sub x14, x14, x10\n" - "beq 12f\n" - "mov x12, #0x0\n" - "add x20, x13, %x[ld_in_row], LSL #2\n" - ".inst 0x25304500 // psel p0.s, p1.s/Z, p8.s[w12]\n" - "ld1w { z23.s }, p0/Z, [x13]\n" - ".inst 0x25704500 // psel p0.s, p1.s/Z, p8.s[w12, #1]\n" - "add x13, x13, %x[ld_in_col], LSL #2\n" - "ld1w { z24.s }, p0/Z, [x20]\n" - ".inst 0x25b04500 // psel p0.s, p1.s/Z, p8.s[w12, #2]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "ld1w { z25.s }, p0/Z, [x20]\n" - ".inst 0x25f04500 // psel p0.s, p1.s/Z, p8.s[w12, #3]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "mov x12, #0x4\n" - "ld1w { z26.s }, p0/Z, [x20]\n" - ".inst 0x25304500 // psel p0.s, p1.s/Z, p8.s[w12]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - ".inst 0xc13e1ae0 // fmla za.s[x8, 0], { z23.s-z26.s }, z14.s\n" - "ld1w { z27.s }, p0/Z, [x20]\n" - ".inst 0x25704500 // psel p0.s, p1.s/Z, p8.s[w12, #1]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "ld1w { z28.s }, p0/Z, [x20]\n" - ".inst 0xc13c1b00 // fmla za.s[x8, 0], { z24.s-z27.s }, z12.s\n" - ".inst 0xc1341b20 // fmla za.s[x8, 0], { z25.s-z28.s }, z4.s\n" - "12:" // Padded: 1 priming loads - "mov x12, #0x0\n" - "add x20, x13, %x[ld_in_row], LSL #2\n" - ".inst 0x25304500 // psel p0.s, p1.s/Z, p8.s[w12]\n" - "ld1w { z25.s }, p0/Z, [x13]\n" - ".inst 0x25704500 // psel p0.s, p1.s/Z, p8.s[w12, #1]\n" - "add x13, x13, %x[ld_in_col], LSL #2\n" - "ld1w { z26.s }, p0/Z, [x20]\n" - ".inst 0x25b04500 // psel p0.s, p1.s/Z, p8.s[w12, #2]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "ld1w { z27.s }, p0/Z, [x20]\n" - ".inst 0x25f04500 // psel p0.s, p1.s/Z, p8.s[w12, #3]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "mov x12, #0x4\n" - "ld1w { z28.s }, p0/Z, [x20]\n" - ".inst 0x25304500 // psel p0.s, p1.s/Z, p8.s[w12]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - ".inst 0xc13f1b20 // fmla za.s[x8, 0], { z25.s-z28.s }, z15.s\n" - "ld1w { z29.s }, p0/Z, [x20]\n" - ".inst 0x25704500 // psel p0.s, p1.s/Z, p8.s[w12, #1]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - ".inst 0xc13e1b21 // fmla za.s[x8, 1], { z25.s-z28.s }, z14.s\n" - "ld1w { z30.s }, p0/Z, [x20]\n" - ".inst 0xc13d1b40 // fmla za.s[x8, 0], { z26.s-z29.s }, z13.s\n" - ".inst 0xc13c1b41 // fmla za.s[x8, 1], { z26.s-z29.s }, z12.s\n" - ".inst 0xc1351b60 // fmla za.s[x8, 0], { z27.s-z30.s }, z5.s\n" - ".inst 0xc1341b61 // fmla za.s[x8, 1], { z27.s-z30.s }, z4.s\n" - "13:" // Padded: 0 priming loads - "cbz x14, 16f\n" - "mov x12, #0x0\n" - "add x20, x13, %x[ld_in_row], LSL #2\n" - ".inst 0x25304500 // psel p0.s, p1.s/Z, p8.s[w12]\n" - "sub x14, x14, #0x1\n" - "sub x11, x11, #0x1\n" - "cmp x14, x11\n" - "ld1w { z20.s }, p0/Z, [x13]\n" - ".inst 0x25704500 // psel p0.s, p1.s/Z, p8.s[w12, #1]\n" - "csel x21, x14, x11, LT\n" - "add x13, x13, %x[ld_in_col], LSL #2\n" - "sub x11, x11, x21\n" - "ld1w { z21.s }, p0/Z, [x20]\n" - ".inst 0x25b04500 // psel p0.s, p1.s/Z, p8.s[w12, #2]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "ld1w { z22.s }, p0/Z, [x20]\n" - ".inst 0x25f04500 // psel p0.s, p1.s/Z, p8.s[w12, #3]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "mov x12, #0x4\n" - "ld1w { z23.s }, p0/Z, [x20]\n" - ".inst 0x25304500 // psel p0.s, p1.s/Z, p8.s[w12]\n" - "add x20, x20, %x[ld_in_row], LSL 
#2\n" - "ld1w { z24.s }, p0/Z, [x20]\n" - ".inst 0x25704500 // psel p0.s, p1.s/Z, p8.s[w12, #1]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "ld1w { z25.s }, p0/Z, [x20]\n" - "cbz x21, 15f\n" - "14:" // Padded: Main loop - "mov x12, #0x0\n" - ".inst 0xc13b1a80 // fmla za.s[x8, 0], { z20.s-z23.s }, z11.s\n" - "add x20, x13, %x[ld_in_row], LSL #2\n" - ".inst 0x25304500 // psel p0.s, p1.s/Z, p8.s[w12]\n" - ".inst 0xc13f1a81 // fmla za.s[x8, 1], { z20.s-z23.s }, z15.s\n" - "subs x21, x21, #0x1\n" - ".inst 0xc13e1a82 // fmla za.s[x8, 2], { z20.s-z23.s }, z14.s\n" - "ld1w { z20.s }, p0/Z, [x13]\n" - ".inst 0x25704500 // psel p0.s, p1.s/Z, p8.s[w12, #1]\n" - ".inst 0xc1301aa0 // fmla za.s[x8, 0], { z21.s-z24.s }, z0.s\n" - "add x13, x13, %x[ld_in_col], LSL #2\n" - ".inst 0xc13d1aa1 // fmla za.s[x8, 1], { z21.s-z24.s }, z13.s\n" - ".inst 0xc13c1aa2 // fmla za.s[x8, 2], { z21.s-z24.s }, z12.s\n" - "ld1w { z21.s }, p0/Z, [x20]\n" - ".inst 0x25b04500 // psel p0.s, p1.s/Z, p8.s[w12, #2]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - ".inst 0xc1371ac0 // fmla za.s[x8, 0], { z22.s-z25.s }, z7.s\n" - ".inst 0xc1351ac1 // fmla za.s[x8, 1], { z22.s-z25.s }, z5.s\n" - ".inst 0xc1341ac2 // fmla za.s[x8, 2], { z22.s-z25.s }, z4.s\n" - "ld1w { z22.s }, p0/Z, [x20]\n" - ".inst 0x25f04500 // psel p0.s, p1.s/Z, p8.s[w12, #3]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - "mov x12, #0x4\n" - ".inst 0xc0060c1c // mova { z28.d-z31.d }, za.d[x8, #0]\n" - "add x8, x8, #0x1\n" - "ld1w { z23.s }, p0/Z, [x20]\n" - ".inst 0x25304500 // psel p0.s, p1.s/Z, p8.s[w12]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - ".inst 0xc0040e02 // mova za.d[x8, #2], { z16.d-z19.d }\n" - "ld1w { z24.s }, p0/Z, [x20]\n" - ".inst 0x25704500 // psel p0.s, p1.s/Z, p8.s[w12, #1]\n" - "add x20, x20, %x[ld_in_row], LSL #2\n" - ".inst 0xc1a9c87c // fclamp { z28.s-z31.s }, z3.s, z9.s\n" - "ld1w { z25.s }, p0/Z, [x20]\n" - "st1w { z28.s }, p1, [x9]\n" - "add x9, x9, x27, LSL #2\n" - "st1w { z29.s }, p1, [x28]\n" - "add x28, x28, x26, LSL #2\n" - "st1w { z30.s }, p1, [x25]\n" - "add x25, x25, x23, LSL #2\n" - "st1w { z31.s }, p1, [x24]\n" - "add x24, x24, x22, LSL #2\n" - "bgt 14b\n" - "15:" // Main loop tail - ".inst 0xc13b1a80 // fmla za.s[x8, 0], { z20.s-z23.s }, z11.s\n" - ".inst 0xc13f1a81 // fmla za.s[x8, 1], { z20.s-z23.s }, z15.s\n" - ".inst 0xc13e1a82 // fmla za.s[x8, 2], { z20.s-z23.s }, z14.s\n" - ".inst 0xc1301aa0 // fmla za.s[x8, 0], { z21.s-z24.s }, z0.s\n" - ".inst 0xc13d1aa1 // fmla za.s[x8, 1], { z21.s-z24.s }, z13.s\n" - ".inst 0xc13c1aa2 // fmla za.s[x8, 2], { z21.s-z24.s }, z12.s\n" - ".inst 0xc1371ac0 // fmla za.s[x8, 0], { z22.s-z25.s }, z7.s\n" - ".inst 0xc1351ac1 // fmla za.s[x8, 1], { z22.s-z25.s }, z5.s\n" - ".inst 0xc1341ac2 // fmla za.s[x8, 2], { z22.s-z25.s }, z4.s\n" - ".inst 0xc0060c14 // mova { z20.d-z23.d }, za.d[x8, #0]\n" - "add x8, x8, #0x1\n" - ".inst 0xc0040e02 // mova za.d[x8, #2], { z16.d-z19.d }\n" - ".inst 0xc1a9c874 // fclamp { z20.s-z23.s }, z3.s, z9.s\n" - "st1w { z20.s }, p1, [x9]\n" - "add x9, x9, x27, LSL #2\n" - "st1w { z21.s }, p1, [x28]\n" - "add x28, x28, x26, LSL #2\n" - "st1w { z22.s }, p1, [x25]\n" - "add x25, x25, x23, LSL #2\n" - "st1w { z23.s }, p1, [x24]\n" - "add x24, x24, x22, LSL #2\n" - "16:" // Main loop skip tail - "cbz x11, 18f\n" - "17:" // Right padding loop - ".inst 0xc0060c04 // mova { z4.d-z7.d }, za.d[x8, #0]\n" - "add x8, x8, #0x1\n" - "subs x11, x11, #0x1\n" - ".inst 0xc0040e02 // mova za.d[x8, #2], { z16.d-z19.d }\n" - ".inst 0xc1a9c864 // fclamp { z4.s-z7.s }, z3.s, z9.s\n" 
- "st1w { z4.s }, p1, [x9]\n" - "add x9, x9, x27, LSL #2\n" - "st1w { z5.s }, p1, [x28]\n" - "add x28, x28, x26, LSL #2\n" - "st1w { z6.s }, p1, [x25]\n" - "add x25, x25, x23, LSL #2\n" - "st1w { z7.s }, p1, [x24]\n" - "add x24, x24, x22, LSL #2\n" - "bgt 17b\n" - "18:" // End - "ldr x20, [%x[args], %[offsetof_Args_weights]]\n" - "incw x15\n" - "whilelt p1.s, x15, x16\n" - "incb x20, ALL, MUL #9\n" - "str x20, [%x[args], %[offsetof_Args_weights]]\n" - "ldr x21, [%x[args], %[offsetof_Args_ld_in_vl]]\n" - "ldr x20, [%x[args], %[offsetof_Args_inptr]]\n" - "add x20, x20, x21, LSL #2\n" - "str x20, [%x[args], %[offsetof_Args_inptr]]\n" - "ldr x25, [%x[args], %[offsetof_Args_outptrs]]\n" - "ldr x24, [%x[args], %[offsetof_Args_ld_out_vls]]\n" - "ldp x23, x22, [x25, #0x0]\n" - "ldp x21, x20, [x24, #0x0]\n" - "add x23, x23, x21, LSL #2\n" - "add x22, x22, x20, LSL #2\n" - "stp x23, x22, [x25, #0x0]\n" - "ldp x23, x22, [x25, #0x10]\n" - "ldp x21, x20, [x24, #0x10]\n" - "add x23, x23, x21, LSL #2\n" - "add x22, x22, x20, LSL #2\n" - "stp x23, x22, [x25, #0x10]\n" - "b.any 1b\n" - ".inst 0xd503467f // SMSTOP\n" - : - : [args] "r"(&args), [ld_in_col] "r"(ld_in_col), [ld_in_row] "r"(ld_in_row), - [offsetof_Args_bias] "I"(offsetof(Args, bias)), [offsetof_Args_clamp_max] "I"(offsetof(Args, clamp_max)), - [offsetof_Args_clamp_min] "I"(offsetof(Args, clamp_min)), - [offsetof_Args_current_channel] "I"(offsetof(Args, current_channel)), - [offsetof_Args_inptr] "I"(offsetof(Args, inptr)), [offsetof_Args_input_cols] "I"(offsetof(Args, input_cols)), - [offsetof_Args_ld_in_vl] "I"(offsetof(Args, ld_in_vl)), - [offsetof_Args_ld_out_cols] "I"(offsetof(Args, ld_out_cols)), - [offsetof_Args_ld_out_vls] "I"(offsetof(Args, ld_out_vls)), - [offsetof_Args_n_channels] "I"(offsetof(Args, n_channels)), - [offsetof_Args_outptrs] "I"(offsetof(Args, outptrs)), - [offsetof_Args_output_cols] "I"(offsetof(Args, output_cols)), - [offsetof_Args_pad_bottom] "I"(offsetof(Args, pad_bottom)), - [offsetof_Args_pad_left] "I"(offsetof(Args, pad_left)), [offsetof_Args_pad_top] "I"(offsetof(Args, pad_top)), - [offsetof_Args_weights] "I"(offsetof(Args, weights)) - : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", - "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", "x24", "x25", - "x26", "x27", "x28", "x7", "x8", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", - "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", - "z4", "z5", "z6", "z7", "z8", "z9"); + kai_kernel_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(&args, ld_in_row, ld_in_col); } #endif // Architectural features check. diff --git a/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h index e1b00fcd..920cd2f6 100644 --- a/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h +++ b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla.h @@ -12,10 +12,6 @@ extern "C" { #include -/// @return output size in bytes. -size_t kai_get_dst_size_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( - size_t out_height, size_t out_width, size_t num_channels); - /// @return Maximum number of rows of output data produced by this kernel. 
size_t kai_get_m_step_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(void); @@ -25,6 +21,14 @@ size_t kai_get_filter_height_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( /// @return Height of the filter used by this kernel. size_t kai_get_filter_width_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(void); +/// This is the packing parameter which must be used to pack the LHS and RHS matrix. +/// @return The kr value. +size_t kai_get_kr_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(void); + +/// @return output size in bytes. +size_t kai_get_dst_size_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( + const size_t out_height, const size_t out_width, const size_t num_channels); + /// @param[in] out_row_idx the row index of the output matrix /// @param[in] stride_out_row Output row stride in bytes /// @return offset to element in output/destination matrix. @@ -51,7 +55,7 @@ size_t kai_get_dst_offset_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla(siz void kai_run_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla( const void* inptr, size_t stride_in_row, size_t stride_in_col, unsigned int pad_top, unsigned int pad_left, unsigned int valid_input_rows, const void* weights, const void* bias, void* outptr_start, size_t stride_out_col, - size_t stride_out_row, unsigned int valid_out_rows, float act_min, float act_max, float pad_value); + size_t stride_out_row, unsigned int valid_out_rows, float act_min, float act_max, const float pad_value); #ifdef __cplusplus } // extern "C" diff --git a/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla_asm.S b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla_asm.S new file mode 100644 index 00000000..a6276c85 --- /dev/null +++ b/kai/ukernels/dw_conv/depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla_asm.S @@ -0,0 +1,438 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla) +KAI_ASM_FUNCTION_LABEL(kai_kernel_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla) + stp x20, x21, [sp, -144]! 
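+    // Prologue: save callee-saved registers (x20-x28, d8-d15).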
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] + mov x20, #0x6 + KAI_ASM_INST(0xd503477f) // SMSTART ZA + ldr x7, [x0, #0x10] + ptrue p2.b + KAI_ASM_INST(0x25207812) // ptrue pn10.b + ldr x17, [x0, #0x8] + ld1rw { z3.s }, p2/Z, [x0, #80] + ldr x16, [x0, #0x70] + ld1rw { z9.s }, p2/Z, [x0, #84] + sub x20, x20, x7 + ldr x15, [x0, #0x68] + whilelt p1.s, XZR, x16 + whilelt p9.b, XZR, x20 + whilelt p8.b, XZR, x17 + eor p8.b, p2/Z, p8.b, p9.b +KAI_ASM_LABEL(label_1) // Channel loop + ldr x20, [x0, #0x60] + fmov z16.s, #0x0 + cbz x20, label_2 + ld1w { z16.s }, p1/Z, [x20, x15, LSL #2] +KAI_ASM_LABEL(label_2) // Load bias: Done + ldr x14, [x0, #0x20] + mov x23, #0x6 + add x20, x17, x7 + mov z17.d, z16.d + ldr x22, [x0, #0x18] + lsl x21, x1, #0x2 + mov z18.d, z16.d + mov z19.d, z16.d + ldr x13, [x0, #0x0] + mov x8, #0x0 + sub x23, x23, x20 + sub x20, x14, #0x1 + ldr x11, [x0, #0x28] + KAI_ASM_INST(0xa0404ace) // ld1w { z14.s-z15.s }, pn10.b/Z, [x22] + orr x20, x20, x2, LSL #18 + ld1w { z11.s }, p2/Z, [x22, #2, MUL VL] + addvl x22, x22, #3 + orr x20, x16, x20, LSL #20 + KAI_ASM_INST(0xa0404acc) // ld1w { z12.s-z13.s }, pn10.b/Z, [x22] + lsl x20, x20, #0x2 + madd x21, x21, x17, x13 + ld1w { z0.s }, p2/Z, [x22, #2, MUL VL] + addvl x22, x22, #3 + KAI_ASM_INST(0xa0404ac4) // ld1w { z4.s-z5.s }, pn10.b/Z, [x22] + ld1w { z7.s }, p2/Z, [x22, #2, MUL VL] +KAI_ASM_LABEL(label_3) // Issue prefetches + subs x23, x23, #0x1 + KAI_ASM_INST(0xf8b44abc) // rprfm pldstrm, x20, [x21] + add x21, x21, x2, LSL #2 + bgt label_3 + ldr x22, [x0, #0x30] + lsl x21, x1, #0x2 + KAI_ASM_INST(0xc0040e00) // mova za.d[x8, #0], { z16.d-z19.d } + mov x10, #0x2 + ldr x20, [x0, #0x38] + msub x13, x17, x21, x13 + KAI_ASM_INST(0xc0040e01) // mova za.d[x8, #1], { z16.d-z19.d } + ldr x21, [x0, #0x58] + KAI_ASM_INST(0xc0040e02) // mova za.d[x8, #2], { z16.d-z19.d } + ldp x9, x28, [x22], #0x10 + ldp x27, x26, [x20], #0x10 + ldp x25, x24, [x22], #0x10 + ldp x23, x22, [x20], #0x10 + cbz x21, label_5 + cmp x21, x10 + csel x20, x21, x10, LT + sub x21, x21, x20 + sub x10, x10, x20 + cbz x21, label_5 + KAI_ASM_INST(0xc0060c1c) // mova { z28.d-z31.d }, za.d[x8, #0] + sub x11, x11, x21 + KAI_ASM_INST(0xc1a9c87c) // fclamp { z28.s-z31.s }, z3.s, z9.s +KAI_ASM_LABEL(label_4) // Left padding + subs x21, x21, #0x1 + st1w { z28.s }, p1, [x9] + add x9, x9, x27, LSL #2 + st1w { z29.s }, p1, [x28] + add x28, x28, x26, LSL #2 + st1w { z30.s }, p1, [x25] + add x25, x25, x23, LSL #2 + st1w { z31.s }, p1, [x24] + add x24, x24, x22, LSL #2 + bgt label_4 +KAI_ASM_LABEL(label_5) // Left padding: End + adds XZR, x17, x7 + bne label_10 + cbz x10, label_8 + cmp x10, #0x1 + sub x14, x14, x10 + beq label_7 + add x20, x13, x1, LSL #2 + ld1w { z22.s }, p1/Z, [x13] + add x13, x13, x2, LSL #2 + ld1w { z23.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + ld1w { z24.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + ld1w { z25.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + ld1w { z26.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + ld1w { z27.s }, p1/Z, [x20] + KAI_ASM_INST(0xc13e1ac0) // fmla za.s[x8, 0], { z22.s-z25.s }, z14.s + KAI_ASM_INST(0xc13c1ae0) // fmla za.s[x8, 0], { z23.s-z26.s }, z12.s + KAI_ASM_INST(0xc1341b00) // fmla za.s[x8, 0], { z24.s-z27.s }, z4.s +KAI_ASM_LABEL(label_7) // Unpadded: 1 priming loads + add x20, x13, x1, LSL #2 + ld1w { z24.s }, p1/Z, [x13] + add x13, x13, x2, LSL #2 + ld1w { z25.s }, p1/Z, [x20] + add 
x20, x20, x1, LSL #2 + ld1w { z26.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + ld1w { z27.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + ld1w { z28.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + ld1w { z29.s }, p1/Z, [x20] + KAI_ASM_INST(0xc13f1b00) // fmla za.s[x8, 0], { z24.s-z27.s }, z15.s + KAI_ASM_INST(0xc13e1b01) // fmla za.s[x8, 1], { z24.s-z27.s }, z14.s + KAI_ASM_INST(0xc13d1b20) // fmla za.s[x8, 0], { z25.s-z28.s }, z13.s + KAI_ASM_INST(0xc13c1b21) // fmla za.s[x8, 1], { z25.s-z28.s }, z12.s + KAI_ASM_INST(0xc1351b40) // fmla za.s[x8, 0], { z26.s-z29.s }, z5.s + KAI_ASM_INST(0xc1341b41) // fmla za.s[x8, 1], { z26.s-z29.s }, z4.s +KAI_ASM_LABEL(label_8) // Unpadded: 0 priming loads + cbz x14, label_16 + add x20, x13, x1, LSL #2 + ld1w { z20.s }, p1/Z, [x13] + sub x14, x14, #0x1 + ld1w { z21.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + sub x11, x11, #0x1 + ld1w { z22.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + cmp x14, x11 + ld1w { z23.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + csel x21, x14, x11, LT + ld1w { z24.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + add x13, x13, x2, LSL #2 + ld1w { z25.s }, p1/Z, [x20] + sub x11, x11, x21 + cbz x21, label_15 +KAI_ASM_LABEL(label_9) // Unpadded: Main loop + KAI_ASM_INST(0xc13b1a80) // fmla za.s[x8, 0], { z20.s-z23.s }, z11.s + add x20, x13, x1, LSL #2 + subs x21, x21, #0x1 + KAI_ASM_INST(0xc13f1a81) // fmla za.s[x8, 1], { z20.s-z23.s }, z15.s + KAI_ASM_INST(0xc13e1a82) // fmla za.s[x8, 2], { z20.s-z23.s }, z14.s + ld1w { z20.s }, p1/Z, [x13] + add x13, x13, x2, LSL #2 + KAI_ASM_INST(0xc1301aa0) // fmla za.s[x8, 0], { z21.s-z24.s }, z0.s + KAI_ASM_INST(0xc13d1aa1) // fmla za.s[x8, 1], { z21.s-z24.s }, z13.s + KAI_ASM_INST(0xc13c1aa2) // fmla za.s[x8, 2], { z21.s-z24.s }, z12.s + ld1w { z21.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + KAI_ASM_INST(0xc1371ac0) // fmla za.s[x8, 0], { z22.s-z25.s }, z7.s + KAI_ASM_INST(0xc1351ac1) // fmla za.s[x8, 1], { z22.s-z25.s }, z5.s + KAI_ASM_INST(0xc1341ac2) // fmla za.s[x8, 2], { z22.s-z25.s }, z4.s + ld1w { z22.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + ld1w { z23.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + KAI_ASM_INST(0xc0060c1c) // mova { z28.d-z31.d }, za.d[x8, #0] + add x8, x8, #0x1 + ld1w { z24.s }, p1/Z, [x20] + add x20, x20, x1, LSL #2 + KAI_ASM_INST(0xc0040e02) // mova za.d[x8, #2], { z16.d-z19.d } + ld1w { z25.s }, p1/Z, [x20] + KAI_ASM_INST(0xc1a9c87c) // fclamp { z28.s-z31.s }, z3.s, z9.s + st1w { z28.s }, p1, [x9] + add x9, x9, x27, LSL #2 + st1w { z29.s }, p1, [x28] + add x28, x28, x26, LSL #2 + st1w { z30.s }, p1, [x25] + add x25, x25, x23, LSL #2 + st1w { z31.s }, p1, [x24] + add x24, x24, x22, LSL #2 + bgt label_9 + b label_15 +KAI_ASM_LABEL(label_10) // Padded + cbz x10, label_13 + cmp x10, #0x1 + sub x14, x14, x10 + beq label_12 + mov x12, #0x0 + add x20, x13, x1, LSL #2 + KAI_ASM_INST(0x25244500) // psel p0.b, p1.b/Z, p8.b[w12] + ld1w { z23.s }, p0/Z, [x13] + KAI_ASM_INST(0x252c4500) // psel p0.b, p1.b/Z, p8.b[w12, #1] + add x13, x13, x2, LSL #2 + ld1w { z24.s }, p0/Z, [x20] + KAI_ASM_INST(0x25344500) // psel p0.b, p1.b/Z, p8.b[w12, #2] + add x20, x20, x1, LSL #2 + ld1w { z25.s }, p0/Z, [x20] + KAI_ASM_INST(0x253c4500) // psel p0.b, p1.b/Z, p8.b[w12, #3] + add x20, x20, x1, LSL #2 + mov x12, #0x4 + ld1w { z26.s }, p0/Z, [x20] + KAI_ASM_INST(0x25244500) // psel p0.b, p1.b/Z, p8.b[w12] + add x20, x20, x1, LSL #2 + KAI_ASM_INST(0xc13e1ae0) // fmla za.s[x8, 0], { z23.s-z26.s }, z14.s + ld1w { z27.s }, p0/Z, [x20] + KAI_ASM_INST(0x252c4500) // psel p0.b, p1.b/Z, p8.b[w12, #1] 
+ add x20, x20, x1, LSL #2 + ld1w { z28.s }, p0/Z, [x20] + KAI_ASM_INST(0xc13c1b00) // fmla za.s[x8, 0], { z24.s-z27.s }, z12.s + KAI_ASM_INST(0xc1341b20) // fmla za.s[x8, 0], { z25.s-z28.s }, z4.s +KAI_ASM_LABEL(label_12) // Padded: 1 priming loads + mov x12, #0x0 + add x20, x13, x1, LSL #2 + KAI_ASM_INST(0x25244500) // psel p0.b, p1.b/Z, p8.b[w12] + ld1w { z25.s }, p0/Z, [x13] + KAI_ASM_INST(0x252c4500) // psel p0.b, p1.b/Z, p8.b[w12, #1] + add x13, x13, x2, LSL #2 + ld1w { z26.s }, p0/Z, [x20] + KAI_ASM_INST(0x25344500) // psel p0.b, p1.b/Z, p8.b[w12, #2] + add x20, x20, x1, LSL #2 + ld1w { z27.s }, p0/Z, [x20] + KAI_ASM_INST(0x253c4500) // psel p0.b, p1.b/Z, p8.b[w12, #3] + add x20, x20, x1, LSL #2 + mov x12, #0x4 + ld1w { z28.s }, p0/Z, [x20] + KAI_ASM_INST(0x25244500) // psel p0.b, p1.b/Z, p8.b[w12] + add x20, x20, x1, LSL #2 + KAI_ASM_INST(0xc13f1b20) // fmla za.s[x8, 0], { z25.s-z28.s }, z15.s + ld1w { z29.s }, p0/Z, [x20] + KAI_ASM_INST(0x252c4500) // psel p0.b, p1.b/Z, p8.b[w12, #1] + add x20, x20, x1, LSL #2 + KAI_ASM_INST(0xc13e1b21) // fmla za.s[x8, 1], { z25.s-z28.s }, z14.s + ld1w { z30.s }, p0/Z, [x20] + KAI_ASM_INST(0xc13d1b40) // fmla za.s[x8, 0], { z26.s-z29.s }, z13.s + KAI_ASM_INST(0xc13c1b41) // fmla za.s[x8, 1], { z26.s-z29.s }, z12.s + KAI_ASM_INST(0xc1351b60) // fmla za.s[x8, 0], { z27.s-z30.s }, z5.s + KAI_ASM_INST(0xc1341b61) // fmla za.s[x8, 1], { z27.s-z30.s }, z4.s +KAI_ASM_LABEL(label_13) // Padded: 0 priming loads + cbz x14, label_16 + mov x12, #0x0 + add x20, x13, x1, LSL #2 + KAI_ASM_INST(0x25244500) // psel p0.b, p1.b/Z, p8.b[w12] + sub x14, x14, #0x1 + sub x11, x11, #0x1 + cmp x14, x11 + ld1w { z20.s }, p0/Z, [x13] + KAI_ASM_INST(0x252c4500) // psel p0.b, p1.b/Z, p8.b[w12, #1] + csel x21, x14, x11, LT + add x13, x13, x2, LSL #2 + sub x11, x11, x21 + ld1w { z21.s }, p0/Z, [x20] + KAI_ASM_INST(0x25344500) // psel p0.b, p1.b/Z, p8.b[w12, #2] + add x20, x20, x1, LSL #2 + ld1w { z22.s }, p0/Z, [x20] + KAI_ASM_INST(0x253c4500) // psel p0.b, p1.b/Z, p8.b[w12, #3] + add x20, x20, x1, LSL #2 + mov x12, #0x4 + ld1w { z23.s }, p0/Z, [x20] + KAI_ASM_INST(0x25244500) // psel p0.b, p1.b/Z, p8.b[w12] + add x20, x20, x1, LSL #2 + ld1w { z24.s }, p0/Z, [x20] + KAI_ASM_INST(0x252c4500) // psel p0.b, p1.b/Z, p8.b[w12, #1] + add x20, x20, x1, LSL #2 + ld1w { z25.s }, p0/Z, [x20] + cbz x21, label_15 +KAI_ASM_LABEL(label_14) // Padded: Main loop + mov x12, #0x0 + KAI_ASM_INST(0xc13b1a80) // fmla za.s[x8, 0], { z20.s-z23.s }, z11.s + add x20, x13, x1, LSL #2 + KAI_ASM_INST(0x25244500) // psel p0.b, p1.b/Z, p8.b[w12] + KAI_ASM_INST(0xc13f1a81) // fmla za.s[x8, 1], { z20.s-z23.s }, z15.s + subs x21, x21, #0x1 + KAI_ASM_INST(0xc13e1a82) // fmla za.s[x8, 2], { z20.s-z23.s }, z14.s + ld1w { z20.s }, p0/Z, [x13] + KAI_ASM_INST(0x252c4500) // psel p0.b, p1.b/Z, p8.b[w12, #1] + KAI_ASM_INST(0xc1301aa0) // fmla za.s[x8, 0], { z21.s-z24.s }, z0.s + add x13, x13, x2, LSL #2 + KAI_ASM_INST(0xc13d1aa1) // fmla za.s[x8, 1], { z21.s-z24.s }, z13.s + KAI_ASM_INST(0xc13c1aa2) // fmla za.s[x8, 2], { z21.s-z24.s }, z12.s + ld1w { z21.s }, p0/Z, [x20] + KAI_ASM_INST(0x25344500) // psel p0.b, p1.b/Z, p8.b[w12, #2] + add x20, x20, x1, LSL #2 + KAI_ASM_INST(0xc1371ac0) // fmla za.s[x8, 0], { z22.s-z25.s }, z7.s + KAI_ASM_INST(0xc1351ac1) // fmla za.s[x8, 1], { z22.s-z25.s }, z5.s + KAI_ASM_INST(0xc1341ac2) // fmla za.s[x8, 2], { z22.s-z25.s }, z4.s + ld1w { z22.s }, p0/Z, [x20] + KAI_ASM_INST(0x253c4500) // psel p0.b, p1.b/Z, p8.b[w12, #3] + add x20, x20, x1, LSL #2 + mov x12, #0x4 + 
KAI_ASM_INST(0xc0060c1c) // mova { z28.d-z31.d }, za.d[x8, #0] + add x8, x8, #0x1 + ld1w { z23.s }, p0/Z, [x20] + KAI_ASM_INST(0x25244500) // psel p0.b, p1.b/Z, p8.b[w12] + add x20, x20, x1, LSL #2 + KAI_ASM_INST(0xc0040e02) // mova za.d[x8, #2], { z16.d-z19.d } + ld1w { z24.s }, p0/Z, [x20] + KAI_ASM_INST(0x252c4500) // psel p0.b, p1.b/Z, p8.b[w12, #1] + add x20, x20, x1, LSL #2 + KAI_ASM_INST(0xc1a9c87c) // fclamp { z28.s-z31.s }, z3.s, z9.s + ld1w { z25.s }, p0/Z, [x20] + st1w { z28.s }, p1, [x9] + add x9, x9, x27, LSL #2 + st1w { z29.s }, p1, [x28] + add x28, x28, x26, LSL #2 + st1w { z30.s }, p1, [x25] + add x25, x25, x23, LSL #2 + st1w { z31.s }, p1, [x24] + add x24, x24, x22, LSL #2 + bgt label_14 +KAI_ASM_LABEL(label_15) // Main loop tail + KAI_ASM_INST(0xc13b1a80) // fmla za.s[x8, 0], { z20.s-z23.s }, z11.s + KAI_ASM_INST(0xc13f1a81) // fmla za.s[x8, 1], { z20.s-z23.s }, z15.s + KAI_ASM_INST(0xc13e1a82) // fmla za.s[x8, 2], { z20.s-z23.s }, z14.s + KAI_ASM_INST(0xc1301aa0) // fmla za.s[x8, 0], { z21.s-z24.s }, z0.s + KAI_ASM_INST(0xc13d1aa1) // fmla za.s[x8, 1], { z21.s-z24.s }, z13.s + KAI_ASM_INST(0xc13c1aa2) // fmla za.s[x8, 2], { z21.s-z24.s }, z12.s + KAI_ASM_INST(0xc1371ac0) // fmla za.s[x8, 0], { z22.s-z25.s }, z7.s + KAI_ASM_INST(0xc1351ac1) // fmla za.s[x8, 1], { z22.s-z25.s }, z5.s + KAI_ASM_INST(0xc1341ac2) // fmla za.s[x8, 2], { z22.s-z25.s }, z4.s + KAI_ASM_INST(0xc0060c14) // mova { z20.d-z23.d }, za.d[x8, #0] + add x8, x8, #0x1 + KAI_ASM_INST(0xc0040e02) // mova za.d[x8, #2], { z16.d-z19.d } + KAI_ASM_INST(0xc1a9c874) // fclamp { z20.s-z23.s }, z3.s, z9.s + st1w { z20.s }, p1, [x9] + add x9, x9, x27, LSL #2 + st1w { z21.s }, p1, [x28] + add x28, x28, x26, LSL #2 + st1w { z22.s }, p1, [x25] + add x25, x25, x23, LSL #2 + st1w { z23.s }, p1, [x24] + add x24, x24, x22, LSL #2 +KAI_ASM_LABEL(label_16) // Main loop skip tail + cbz x11, label_18 +KAI_ASM_LABEL(label_17) // Right padding loop + KAI_ASM_INST(0xc0060c04) // mova { z4.d-z7.d }, za.d[x8, #0] + add x8, x8, #0x1 + subs x11, x11, #0x1 + KAI_ASM_INST(0xc0040e02) // mova za.d[x8, #2], { z16.d-z19.d } + KAI_ASM_INST(0xc1a9c864) // fclamp { z4.s-z7.s }, z3.s, z9.s + st1w { z4.s }, p1, [x9] + add x9, x9, x27, LSL #2 + st1w { z5.s }, p1, [x28] + add x28, x28, x26, LSL #2 + st1w { z6.s }, p1, [x25] + add x25, x25, x23, LSL #2 + st1w { z7.s }, p1, [x24] + add x24, x24, x22, LSL #2 + bgt label_17 +KAI_ASM_LABEL(label_18) // End + ldr x20, [x0, #0x18] + incw x15 + whilelt p1.s, x15, x16 + incb x20, ALL, MUL #9 + str x20, [x0, #0x18] + ldr x21, [x0, #0x40] + ldr x20, [x0, #0x0] + add x20, x20, x21, LSL #2 + str x20, [x0, #0x0] + ldr x25, [x0, #0x30] + ldr x24, [x0, #0x48] + ldp x23, x22, [x25, #0x0] + ldp x21, x20, [x24, #0x0] + add x23, x23, x21, LSL #2 + add x22, x22, x20, LSL #2 + stp x23, x22, [x25, #0x0] + ldp x23, x22, [x25, #0x10] + ldp x21, x20, [x24, #0x10] + add x23, x23, x21, LSL #2 + add x22, x22, x20, LSL #2 + stp x23, x22, [x25, #0x10] + b.any label_1 + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla) + + KAI_ASM_END -- GitLab From e5134e8803b59487270f3768cf77c776217e5ebd Mon Sep 17 00:00:00 2001 From: Mohammed Suhail Munshi Date: Tue, 29 Jul 2025 11:37:46 +0100 Subject: [PATCH 6/6] Fix Bazel builds for dw conv 
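- Adds a kai/ukernels/dw_conv Bazel package with separate SME2 targets for the
  RHS packing kernel and the assembly planar kernel, and wires it into the
  top-level kleidiai library.
- Replaces the direct inclusion of test/reference/depthwise_conv.cpp in the
  planar test with an explicit template instantiation in the reference
  implementation, so the test only needs the header.

For context, the instantiation pattern is sketched below. The names and
signature are illustrative only (they are not the actual KleidiAI reference
API); the point is that when the element type cannot be deduced from the
arguments, the definition can stay in one translation unit as long as that
unit provides an explicit instantiation for every type the tests use.

    #include <cstddef>
    #include <vector>

    // ref.hpp (illustrative): declaration only. T is not deducible from the
    // void* argument, so callers must name it explicitly.
    template <typename T>
    std::vector<T> reference_copy(const void* src, size_t len);

    // ref.cpp (illustrative): definition plus an explicit instantiation for
    // float, so test TUs that include only the header still link.
    template <typename T>
    std::vector<T> reference_copy(const void* src, size_t len) {
        const T* s = static_cast<const T*>(src);
        return std::vector<T>(s, s + len);
    }
    template std::vector<float> reference_copy<float>(const void*, size_t);

    // test.cpp (illustrative): includes only ref.hpp and calls
    // reference_copy<float>(...), linking against the instantiation above.
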
Signed-off-by: Mohammed Suhail Munshi --- BUILD.bazel | 3 +- kai/ukernels/dw_conv/BUILD.bazel | 52 ++++++++++++++++++++++++++++ test/reference/depthwise_conv.cpp | 6 ++++ test/tests/depthwise_planar_test.cpp | 1 - 4 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 kai/ukernels/dw_conv/BUILD.bazel diff --git a/BUILD.bazel b/BUILD.bazel index c4abdc25..7dbe5166 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates # # SPDX-License-Identifier: Apache-2.0 # @@ -36,6 +36,7 @@ kai_c_library( name = "kleidiai", visibility = ["//visibility:public"], deps = [ + "//kai/ukernels/dw_conv", "//kai/ukernels/matmul", ], ) diff --git a/kai/ukernels/dw_conv/BUILD.bazel b/kai/ukernels/dw_conv/BUILD.bazel new file mode 100644 index 00000000..baeb21f0 --- /dev/null +++ b/kai/ukernels/dw_conv/BUILD.bazel @@ -0,0 +1,52 @@ +# +# SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +load( + "//:kai_defs.bzl", + "kai_c_library", + "kai_cpu_sme2", +) + +package(default_visibility = ["//visibility:private"]) + +# buildifier: keep sorted +SME2_KERNELS = [ + "pack/kai_rhs_dw_conv_pack_x32p1vl_x32_sme", +] + +SME2_KERNELS_ASM = [ + "depthwise_planar_f32_f32p_f32p/kai_dw_conv_f32_f32_f32p1vl_3x3_s1_4xc_planar_sme2_mla", +] + +kai_c_library( + name = "sme2_impl", + srcs = [ukernel + ".c" for ukernel in SME2_KERNELS], + cpu_uarch = kai_cpu_sme2(), + textual_hdrs = [ukernel + ".h" for ukernel in SME2_KERNELS], +) + +kai_c_library( + name = "sme2_impl_asm", + srcs = [ukernel + "_asm.S" for ukernel in SME2_KERNELS_ASM] + [ukernel + ".c" for ukernel in SME2_KERNELS_ASM], + cpu_uarch = kai_cpu_sme2(), + textual_hdrs = [ukernel + ".h" for ukernel in SME2_KERNELS_ASM], +) + +kai_c_library( + name = "interface", + textual_hdrs = glob(["**/*_interface.h"]), + visibility = ["//visibility:public"], +) + +kai_c_library( + name = "dw_conv", + visibility = ["//visibility:public"], + deps = [ + ":interface", + ":sme2_impl", + ":sme2_impl_asm", + ], +) diff --git a/test/reference/depthwise_conv.cpp b/test/reference/depthwise_conv.cpp index 9bfbad9b..845b9fac 100644 --- a/test/reference/depthwise_conv.cpp +++ b/test/reference/depthwise_conv.cpp @@ -70,4 +70,10 @@ Buffer depthwise_reference( return dst; } +// Explicit template +template Buffer depthwise_reference( + const size_t batches, const size_t in_height, const size_t in_width, const size_t channels, + const size_t filter_height, const size_t filter_width, const void* feature_map, const void* weights, + const void* bias, const Padding2D& pad); + } // namespace kai::test diff --git a/test/tests/depthwise_planar_test.cpp b/test/tests/depthwise_planar_test.cpp index 7c71c28d..3164e51d 100644 --- a/test/tests/depthwise_planar_test.cpp +++ b/test/tests/depthwise_planar_test.cpp @@ -25,7 +25,6 @@ #include "test/common/round.hpp" #include "test/common/sme.hpp" #include "test/reference/clamp.hpp" -#include "test/reference/depthwise_conv.cpp" #include "test/reference/depthwise_conv.hpp" #include "test/reference/fill.hpp" -- GitLab