diff --git a/CMakeLists.txt b/CMakeLists.txt index 3ce4a125f67f99c5ab637c991ba1861943bd4cca..c9f6566f2f0904a750358dbc19aba8c601a0bc48 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.26.0) set(CMAKE_CXX_STANDARD 17) -project(cme-usecase) +project(ai-camera-pipelines) set(ENABLE_SME2 1 CACHE STRING "Flag to enable SME2/SSVE functions") set(WITH_OPENMP 1 CACHE STRING "Flag to enable OpenMP parallelization") @@ -10,6 +10,9 @@ set(ARMNN_TFLITE_PARSER 1 CACHE STRING "Flag to enable/disable ARMNN TFLite Pars set(CINEMATIC_MODE 1 CACHE STRING "Flag to add subdirectory cinematic_mode to cmake build, default ON") set(LOW_LIGHT_IMAGE_ENHANCEMENT 1 CACHE STRING "Flag to add subdirectory low_light_image_enhancement to cmake build, default ON") +set(DENOISER 1 CACHE STRING "Flag to add subdirectory denoiser to cmake build, default ON") +set(DENOISER_ZA16 0 CACHE STRING "Flag to add za16 denoiser variant to cmake build, default OFF") +set(DENOISER_MOP4 0 CACHE STRING "Flag to add mop4 denoiser variant to cmake build, default OFF") if (ENABLE_SME2) set(KLEIDICV_ENABLE_SME2 1) diff --git a/cmake/set_isa_target_compile_options.cmake b/cmake/set_isa_target_compile_options.cmake index ce7bb61d6db025db33c60f3a7247292a8770754f..7562238faf256f7f49dc684174cccd640486deb8 100644 --- a/cmake/set_isa_target_compile_options.cmake +++ b/cmake/set_isa_target_compile_options.cmake @@ -9,6 +9,8 @@ function(set_isa_target_compile_options TARGET) ) target_compile_definitions(${TARGET} PRIVATE ENABLE_SME2=${ENABLE_SME2} + ARM_STREAMING=__arm_streaming + ARM_STREAMING_COMPATIBLE=__arm_streaming_compatible ) else () check_cxx_compiler_flag("-march=armv9-a+sve2" SUPPORTS_SVE2) @@ -18,5 +20,9 @@ function(set_isa_target_compile_options TARGET) target_compile_options(${TARGET} PRIVATE "-march=armv9-a+sve2" ) + target_compile_definitions(${TARGET} PRIVATE + ARM_STREAMING= + ARM_STREAMING_COMPATIBLE= + ) endif() endfunction() \ No newline at end of file diff --git a/usecase/CMakeLists.txt b/usecase/CMakeLists.txt index c2b3b4c5ef8869263e00c5b497c45080469b3826..75c41b8cb3279cd739ae5093fb0ede2bd8161fc8 100644 --- a/usecase/CMakeLists.txt +++ b/usecase/CMakeLists.txt @@ -4,3 +4,6 @@ endif() if (LOW_LIGHT_IMAGE_ENHANCEMENT) add_subdirectory(low_light_image_enhancement) endif() +if (DENOISER) + add_subdirectory(denoiser) +endif() \ No newline at end of file diff --git a/usecase/denoiser/CMakeLists.txt b/usecase/denoiser/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..82bfe4a4ce1ecf9d0c9dab72f188facbc25e07ac --- /dev/null +++ b/usecase/denoiser/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(core) \ No newline at end of file diff --git a/usecase/denoiser/core/CMakeLists.txt b/usecase/denoiser/core/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..18b668b42623b01548d4d7d43ccdb66041f37dd8 --- /dev/null +++ b/usecase/denoiser/core/CMakeLists.txt @@ -0,0 +1,24 @@ +add_library(collapsenet + collapse_net.h + collapse_net_api.cpp + collapse_net_sme2.cpp +) + +set_isa_target_compile_options(collapsenet) + +if (DENOISER_MOP4) + target_compile_definitions(collapsenet PRIVATE + DENOISER_MOP4=1 + ) +endif () + +if (DENOISER_ZA16) + target_compile_definitions(collapsenet PRIVATE + DENOISER_ZA16=1 + ) +endif () + +target_link_libraries(collapsenet + PRIVATE + acle +) \ No newline at end of file diff --git a/usecase/denoiser/core/cc.h b/usecase/denoiser/core/cc.h new file mode 100644 index 
0000000000000000000000000000000000000000..172c14da986a0fc948aee046bbf86f33943260b8 --- /dev/null +++ b/usecase/denoiser/core/cc.h @@ -0,0 +1,112 @@ +#include "arm_sve.h" + +// --- Colour conversion utility functions in SVE Intrinsic --- + +#if ENABLE_SME2 +#define ARM_STREAMING_COMPATIBLE __arm_streaming_compatible +#else +#define ARM_STREAMING_COMPATIBLE +#endif + +inline svfloat16x4_t +rggb_to_uggv_f16(const svbool_t pg, + const svfloat16x4_t rggb) ARM_STREAMING_COMPATIBLE { + svfloat16_t mean_g = svdiv_n_f16_x( + pg, svadd_f16_x(pg, svget4_f16(rggb, 1), svget4_f16(rggb, 2)), 2); + return svcreate4_f16(svsub_f16_x(pg, svget4_f16(rggb, 0), mean_g), + svget4_f16(rggb, 1), svget4_f16(rggb, 2), + svsub_f16_x(pg, svget4_f16(rggb, 3), mean_g)); +} + +inline svfloat32x4_t +rggb_to_uggv_f32(const svbool_t pg, + const svfloat32x4_t rggb) ARM_STREAMING_COMPATIBLE { + svfloat32_t mean_g = svdiv_n_f32_x( + pg, svadd_f32_x(pg, svget4_f32(rggb, 1), svget4_f32(rggb, 2)), 2); + return svcreate4_f32(svsub_f32_x(pg, svget4_f32(rggb, 0), mean_g), + svget4_f32(rggb, 1), svget4_f32(rggb, 2), + svsub_f32_x(pg, svget4_f32(rggb, 3), mean_g)); +} + +inline svfloat16x4_t +uggv_to_rggb_f16(const svbool_t pg, + const svfloat16x4_t uggv) ARM_STREAMING_COMPATIBLE { + svfloat16_t mean_g = svdiv_n_f16_x( + pg, svadd_f16_x(pg, svget4_f16(uggv, 1), svget4_f16(uggv, 2)), 2); + return svcreate4_f16(svadd_f16_x(pg, svget4_f16(uggv, 0), mean_g), + svget4_f16(uggv, 1), svget4_f16(uggv, 2), + svadd_f16_x(pg, svget4_f16(uggv, 3), mean_g)); +} + +inline svfloat32x4_t +uggv_to_rggb_f32(const svbool_t pg, + const svfloat32x4_t uggv) ARM_STREAMING_COMPATIBLE { + svfloat32_t mean_g = svdiv_n_f32_x( + pg, svadd_f32_x(pg, svget4_f32(uggv, 1), svget4_f32(uggv, 2)), 2); + return svcreate4_f32(svadd_f32_x(pg, svget4_f32(uggv, 0), mean_g), + svget4_f32(uggv, 1), svget4_f32(uggv, 2), + svadd_f32_x(pg, svget4_f32(uggv, 3), mean_g)); +} + +inline svfloat16x3_t rgb_to_yuv(const svbool_t pg, const svfloat16x3_t rgb) + ARM_STREAMING_COMPATIBLE { + + svfloat16_t y = svdup_n_f16_x(pg, (float16_t)0.0); + y = svmla_n_f16_x(pg, y, svget3_f16(rgb, 0), (float16_t)0.299); + y = svmla_n_f16_x(pg, y, svget3_f16(rgb, 1), (float16_t)0.587); + y = svmla_n_f16_x(pg, y, svget3_f16(rgb, 2), (float16_t)0.114); + + svfloat16_t u = svdup_n_f16_x(pg, (float16_t)0.0); + u = svmla_n_f16_x(pg, u, svget3_f16(rgb, 0), (float16_t)-0.14713); + u = svmla_n_f16_x(pg, u, svget3_f16(rgb, 1), (float16_t)-0.28886); + u = svmla_n_f16_x(pg, u, svget3_f16(rgb, 2), (float16_t)0.436); + + svfloat16_t v = svdup_n_f16_x(pg, (float16_t)0.0); + v = svmla_n_f16_x(pg, v, svget3_f16(rgb, 0), (float16_t)0.615); + v = svmla_n_f16_x(pg, v, svget3_f16(rgb, 1), (float16_t)-0.51499); + v = svmla_n_f16_x(pg, v, svget3_f16(rgb, 2), (float16_t)-0.10001); + + return svcreate3_f16(y, u, v); +} + +inline svfloat16x3_t square_and_ccm(const svbool_t pg, const svfloat16x3_t rgb) + ARM_STREAMING_COMPATIBLE { + svfloat16x3_t rgb_sq = + svcreate3_f16(svmul_f16_x(pg, svget3_f16(rgb, 0), svget3_f16(rgb, 0)), + svmul_f16_x(pg, svget3_f16(rgb, 1), svget3_f16(rgb, 1)), + svmul_f16_x(pg, svget3_f16(rgb, 2), svget3_f16(rgb, 2))); + svfloat16_t r = svdup_n_f16_x(pg, 0); + r = svmla_n_f16_x(pg, r, svget3_f16(rgb_sq, 0), 1.6143579483032227); + r = svmla_n_f16_x(pg, r, svget3_f16(rgb_sq, 1), -0.5796861052513123); + r = svmla_n_f16_x(pg, r, svget3_f16(rgb_sq, 2), -0.03467187657952309); + svfloat16_t g = svdup_n_f16_x(pg, 0); + g = svmla_n_f16_x(pg, g, svget3_f16(rgb_sq, 0), -0.14471182227134705); + g = 
svmla_n_f16_x(pg, g, svget3_f16(rgb_sq, 1), 1.3473817110061646); + g = svmla_n_f16_x(pg, g, svget3_f16(rgb_sq, 2), -0.20266985893249512); + svfloat16_t b = svdup_n_f16_x(pg, 0); + b = svmla_n_f16_x(pg, b, svget3_f16(rgb_sq, 0), 0.029018063098192215); + b = svmla_n_f16_x(pg, b, svget3_f16(rgb_sq, 1), -0.8362194895744324); + b = svmla_n_f16_x(pg, b, svget3_f16(rgb_sq, 2), 1.8072013854980469); + return svcreate3_f16(r, g, b); +} + +inline svfloat32x3_t square_and_ccm(const svbool_t pg, const svfloat32x3_t rgb) + ARM_STREAMING_COMPATIBLE { + svfloat32x3_t rgb_sq = + svcreate3_f32(svmul_f32_x(pg, svget3_f32(rgb, 0), svget3_f32(rgb, 0)), + svmul_f32_x(pg, svget3_f32(rgb, 1), svget3_f32(rgb, 1)), + svmul_f32_x(pg, svget3_f32(rgb, 2), svget3_f32(rgb, 2))); + svfloat32_t r = svdup_n_f32_x(pg, 0); + r = svmla_n_f32_x(pg, r, svget3_f32(rgb_sq, 0), 1.6143579483032227); + r = svmla_n_f32_x(pg, r, svget3_f32(rgb_sq, 1), -0.5796861052513123); + r = svmla_n_f32_x(pg, r, svget3_f32(rgb_sq, 2), -0.03467187657952309); + svfloat32_t g = svdup_n_f32_x(pg, 0); + g = svmla_n_f32_x(pg, g, svget3_f32(rgb_sq, 0), -0.14471182227134705); + g = svmla_n_f32_x(pg, g, svget3_f32(rgb_sq, 1), 1.3473817110061646); + g = svmla_n_f32_x(pg, g, svget3_f32(rgb_sq, 2), -0.20266985893249512); + svfloat32_t b = svdup_n_f32_x(pg, 0); + b = svmla_n_f32_x(pg, b, svget3_f32(rgb_sq, 0), 0.029018063098192215); + b = svmla_n_f32_x(pg, b, svget3_f32(rgb_sq, 1), -0.8362194895744324); + b = svmla_n_f32_x(pg, b, svget3_f32(rgb_sq, 2), 1.8072013854980469); + return svcreate3_f32(r, g, b); +} \ No newline at end of file diff --git a/usecase/denoiser/core/collapse_net.h b/usecase/denoiser/core/collapse_net.h new file mode 100644 index 0000000000000000000000000000000000000000..86fdd590d987152affcb77ac5c6b989d44886de0 --- /dev/null +++ b/usecase/denoiser/core/collapse_net.h @@ -0,0 +1,217 @@ +#if ENABLE_SME2 + +#include "arm_sme.h" +#include <cstddef> +#include <cstdint> + +#define ARM_INOUT_ZA __arm_inout("za") +#define ARM_NEW_ZA __arm_new("za") + +struct branch_weights_float16 { + float16_t *weight_conv; + float16_t *bias_conv; + float16_t *weight_pointwise_conv0; + float16_t *bias_pointwise_conv0; + float16_t *weight_pointwise_conv1; + float16_t *bias_pointwise_conv1; + float16_t *weight_blur; + float16_t *bias_blur; +}; + +struct branch_weights_float32 { + float *weight_conv; + float *bias_conv; + float *weight_pointwise_conv0; + float *bias_pointwise_conv0; + float *weight_pointwise_conv1; + float *bias_pointwise_conv1; + float *weight_blur; + float *bias_blur; +}; + +struct branch_weights_interp_unfused_float16 { + float16_t *weight_conv; + float16_t *bias_conv; + float16_t *weight_pointwise_conv0; + float16_t *bias_pointwise_conv0; + float16_t *weight_pointwise_conv1; + float16_t *bias_pointwise_conv1; + float16_t *weight_interp; + float16_t *bias_interp; + float16_t *weight_blur; + float16_t *bias_blur; +}; + +struct branch_weights_interp_collect_unfused_float16 { + float16_t *weight_conv; + float16_t *bias_conv; + float16_t *weight_pointwise_conv0; + float16_t *bias_pointwise_conv0; + float16_t *weight_pointwise_conv1; + float16_t *bias_pointwise_conv1; + float16_t *weight_interp; + float16_t *bias_interp; + float16_t *weight_collect; + float16_t *weight_blur; + float16_t *bias_blur; +}; + +struct collapsenet_weights_float16 { + branch_weights_float16 branch[2][4]; + uint64_t kernel_size; +}; + +struct collapsenet_weights_float32 { + branch_weights_float32 branch[2][4]; + uint64_t kernel_size; +}; + +struct collapsenet_v1_1_weights_float16 { + 
branch_weights_float16 branch_u[2]; + branch_weights_float16 branch_g1[3]; + branch_weights_float16 branch_g2[3]; + branch_weights_float16 branch_v[2]; + uint64_t kernel_size; +}; + +struct collapsenet_v1_1_weights_float32 { + branch_weights_float32 branch_u[2]; + branch_weights_float32 branch_g1[3]; + branch_weights_float32 branch_g2[3]; + branch_weights_float32 branch_v[2]; + uint64_t kernel_size; +}; + +struct collapsenet_v1_1_weights_interp_unfused_float16 { + branch_weights_interp_unfused_float16 branch_u[2]; + branch_weights_interp_unfused_float16 branch_g1[3]; + branch_weights_interp_unfused_float16 branch_g2[3]; + branch_weights_interp_unfused_float16 branch_v[2]; + uint64_t kernel_size; +}; + +struct collapsenet_v1_1_weights_interp_collect_unfused_float16 { + branch_weights_interp_collect_unfused_float16 branch_u[2]; + branch_weights_interp_collect_unfused_float16 branch_g1[3]; + branch_weights_interp_collect_unfused_float16 branch_g2[3]; + branch_weights_interp_collect_unfused_float16 branch_v[2]; + uint64_t kernel_size; +}; + +namespace acle { + +void impl_collapsenet(float16_t *src, collapsenet_weights_float16 &weights, + float16_t *dst, const size_t H, const size_t W, + const size_t O); +void impl_collapsenet(float *src, collapsenet_weights_float32 &weights, + float *dst, const size_t H, const size_t W, + const size_t O); +void impl_collapsenet_v1_1(float16_t *src, + collapsenet_v1_1_weights_float16 &weights, + float16_t *dst, const size_t H, const size_t W, + const size_t O); +void impl_collapsenet_v1_1(float *src, + collapsenet_v1_1_weights_float32 &weights, + float *dst, const size_t H, const size_t W, + const size_t O); +void impl_collapsenet_v1_1_uggv_out_wo_pad_interleave( + float16_t *src, collapsenet_v1_1_weights_float16 &weights, float16_t *dst, + const size_t H, const size_t W, const size_t O); +void impl_collapsenet_v1_1_uggv_out_wo_pad_interleave( + float *src, collapsenet_v1_1_weights_float32 &weights, float *dst, + const size_t H, const size_t W, const size_t O); +void impl_collapsenet_v1_1_uggv_out_wo_pad_interleave_interp_unfused( + float16_t *src, collapsenet_v1_1_weights_interp_unfused_float16 &weights, + float16_t *dst, const size_t H, const size_t W, const size_t O); +void impl_collapsenet_v1_1_uggv_out_wo_pad_interleave_interp_collect_unfused( + float16_t *src, + collapsenet_v1_1_weights_interp_collect_unfused_float16 &weights, + float16_t *dst, const size_t H, const size_t W, const size_t O); +void impl_collapsenet_v1_1_uggv_out_wo_pad_interleave_non_widening( + float16_t *src, collapsenet_v1_1_weights_float16 &weights, float16_t *dst, + const size_t H, const size_t W, const size_t O); +void impl_collapsenet_v1_1_uggv_out_wo_pad_interleave_non_widening_mop4( + float16_t *src, collapsenet_v1_1_weights_float16 &weights, float16_t *dst, + const size_t H, const size_t W, const size_t O); + +namespace sme2 { + +void interleave_channels(float *src, float *dst, const size_t N) ARM_STREAMING; +void interleave_channels(float16_t *src, float16_t *dst, + const size_t N) ARM_STREAMING; +void interleave_uggv_to_rggb_channels(float *src, float *dst, + const size_t N) ARM_STREAMING; +void interleave_uggv_to_rggb_channels(float16_t *src, float16_t *dst, + const size_t N) ARM_STREAMING; +void deinterleave_and_pad_channels(float *src, float *dst, const size_t H, + const size_t W, + const size_t pad) ARM_STREAMING; +void deinterleave_and_pad_channels(float16_t *src, float16_t *dst, + const size_t H, const size_t W, + const size_t pad) ARM_STREAMING; +void 
impl_2x2_bilinear_downsample(const float16_t *src, float16_t *dst, + const uint64_t src_height, + const uint64_t src_width, + const uint64_t src_pad, + const uint64_t dst_pad) ARM_STREAMING; +void impl_2x2_bilinear_downsample(const float *src, float *dst, + const uint64_t src_height, + const uint64_t src_width, + const uint64_t src_pad, + const uint64_t dst_pad) ARM_STREAMING; +void impl_2x2_bilinear_upsample(const float16_t *src, float16_t *dst, + const uint64_t src_height, + const uint64_t src_width, + const uint64_t src_pad, + const uint64_t dst_pad) ARM_STREAMING; +void impl_2x2_bilinear_upsample(const float *src, float *dst, + const uint64_t src_height, + const uint64_t src_width, + const uint64_t src_pad, + const uint64_t dst_pad) ARM_STREAMING; +void impl_2x2_bilinear_upsample_and_merge( + const float16_t *input, const float16_t *input_lr, + const float16_t *output_lr, float16_t *output, const uint64_t input_height, + const uint64_t input_width, const uint64_t input_pad, + const uint64_t output_pad) ARM_STREAMING; +void impl_2x2_bilinear_upsample_and_merge( + const float *input, const float *input_lr, const float *output_lr, + float *output, const uint64_t input_height, const uint64_t input_width, + const uint64_t input_pad, const uint64_t output_pad) ARM_STREAMING; + +ARM_NEW_ZA +void impl_collapsenet_channel_branch(float16_t *src, + branch_weights_float16 *weights, + float16_t *dst, const size_t H, + const size_t W, const size_t O, + const size_t K, const size_t src_pad, + const size_t dst_pad) ARM_STREAMING; +ARM_NEW_ZA void impl_collapsenet_channel_branch( + float *src, branch_weights_float32 *weights, float *dst, const size_t H, + const size_t W, const size_t O, const size_t K, const size_t src_pad, + const size_t dst_pad) ARM_STREAMING; +ARM_NEW_ZA +void impl_collapsenet_channel_branch_interp_unfused( + float16_t *src, branch_weights_interp_unfused_float16 *weights, + float16_t *dst, const size_t H, const size_t W, const size_t O, + const size_t K, const size_t src_pad, const size_t dst_pad) ARM_STREAMING; +ARM_NEW_ZA +void impl_collapsenet_channel_branch_interp_collect_unfused( + float16_t *src, branch_weights_interp_collect_unfused_float16 *weights, + float16_t *dst, const size_t H, const size_t W, const size_t O, + const size_t K, const size_t src_pad, const size_t dst_pad) ARM_STREAMING; +ARM_NEW_ZA void impl_collapsenet_channel_branch_non_widening( + float16_t *src, branch_weights_float16 *weights, float16_t *dst, + const size_t H, const size_t W, const size_t O, const size_t K, + const size_t src_pad, const size_t dst_pad) ARM_STREAMING; +ARM_NEW_ZA void impl_collapsenet_channel_branch_non_widening_mop4( + float16_t *src, branch_weights_float16 *weights, float16_t *dst, + const size_t H, const size_t W, const size_t O, const size_t K, + const size_t src_pad, const size_t dst_pad) ARM_STREAMING; + +} // namespace sme2 + +} // namespace acle + +#endif \ No newline at end of file diff --git a/usecase/denoiser/core/collapse_net_api.cpp b/usecase/denoiser/core/collapse_net_api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dee4cb81aaeb1c464e663ca7732f599bff595255 --- /dev/null +++ b/usecase/denoiser/core/collapse_net_api.cpp @@ -0,0 +1,1474 @@ +#if ENABLE_SME2 + +#include "collapse_net.h" + +namespace acle { + +void impl_collapsenet(float *src, collapsenet_weights_float32 &weights, + float *dst, const size_t H, const size_t W, + const size_t O) { + + // Compute padding + uint64_t pad_2x = weights.kernel_size - 1; + uint64_t pad = pad_2x / 2; + + // 
Compute number of elements + uint64_t N = H * W; + uint64_t N_pad = (H + pad_2x) * (W + pad_2x); + + // Allocate buffers for intermediate data + float *output_blk1 = new float[N_pad](); + float *output_blk2 = new float[N * 4]; + + // --- Deinterleave channels and pad --- + float *padded = new float[N_pad * 4](); + acle::sme2::deinterleave_and_pad_channels(src, padded, H, W, pad); + + // --- Run CollapseNet blocks --- + + // R-Channel + { + acle::sme2::impl_collapsenet_channel_branch(padded, &weights.branch[0][0], + output_blk1, H, W, O, + weights.kernel_size, pad, pad); + acle::sme2::impl_collapsenet_channel_branch( + output_blk1, &weights.branch[1][0], output_blk2, H, W, O, + weights.kernel_size, pad, 0); + } + + // G0-Channel + { + acle::sme2::impl_collapsenet_channel_branch( + padded + N_pad, &weights.branch[0][1], output_blk1, H, W, O, + weights.kernel_size, pad, pad); + acle::sme2::impl_collapsenet_channel_branch( + output_blk1, &weights.branch[1][1], output_blk2 + N, H, W, O, + weights.kernel_size, pad, 0); + } + + // G1-Channel + { + acle::sme2::impl_collapsenet_channel_branch( + padded + 2 * N_pad, &weights.branch[0][2], output_blk1, H, W, O, + weights.kernel_size, pad, pad); + acle::sme2::impl_collapsenet_channel_branch( + output_blk1, &weights.branch[1][2], output_blk2 + 2 * N, H, W, O, + weights.kernel_size, pad, 0); + } + + // B-Channel + { + acle::sme2::impl_collapsenet_channel_branch( + padded + 3 * N_pad, &weights.branch[0][3], output_blk1, H, W, O, + weights.kernel_size, pad, pad); + acle::sme2::impl_collapsenet_channel_branch( + output_blk1, &weights.branch[1][3], output_blk2 + 3 * N, H, W, O, + weights.kernel_size, pad, 0); + } + + // --- Interleave channels --- + acle::sme2::interleave_channels(output_blk2, dst, N); + + // Delete intermediate buffers + delete[] output_blk1; + delete[] output_blk2; + delete[] padded; +} + +void impl_collapsenet(float16_t *src, collapsenet_weights_float16 &weights, + float16_t *dst, const size_t H, const size_t W, + const size_t O) { + + // Compute padding + uint64_t pad_2x = weights.kernel_size - 1; + uint64_t pad = pad_2x / 2; + + // Compute number of elements + uint64_t N = H * W; + uint64_t N_pad = (H + pad_2x) * (W + pad_2x); + + // Allocate buffers for intermediate data + float16_t *output_blk1 = new float16_t[N_pad](); + float16_t *output_blk2 = new float16_t[N * 4]; + + // --- Deinterleave channels and pad --- + float16_t *padded = new float16_t[N_pad * 4](); + acle::sme2::deinterleave_and_pad_channels(src, padded, H, W, pad); + + // --- Run CollapseNet blocks --- + + // R-Channel + { + acle::sme2::impl_collapsenet_channel_branch(padded, &weights.branch[0][0], + output_blk1, H, W, O, + weights.kernel_size, pad, pad); + acle::sme2::impl_collapsenet_channel_branch( + output_blk1, &weights.branch[1][0], output_blk2, H, W, O, + weights.kernel_size, pad, 0); + } + + // G0-Channel + { + acle::sme2::impl_collapsenet_channel_branch( + padded + N_pad, &weights.branch[0][1], output_blk1, H, W, O, + weights.kernel_size, pad, pad); + acle::sme2::impl_collapsenet_channel_branch( + output_blk1, &weights.branch[1][1], output_blk2 + N, H, W, O, + weights.kernel_size, pad, 0); + } + + // G1-Channel + { + acle::sme2::impl_collapsenet_channel_branch( + padded + 2 * N_pad, &weights.branch[0][2], output_blk1, H, W, O, + weights.kernel_size, pad, pad); + acle::sme2::impl_collapsenet_channel_branch( + output_blk1, &weights.branch[1][2], output_blk2 + 2 * N, H, W, O, + weights.kernel_size, pad, 0); + } + + // B-Channel + { + 
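+    // Same two-stage branch as the other channels: the first block writes into the
+    // zero-initialised padded scratch buffer so the second block can read it with the
+    // same border and store its result in the last channel plane of output_blk2.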
acle::sme2::impl_collapsenet_channel_branch( + padded + 3 * N_pad, &weights.branch[0][3], output_blk1, H, W, O, + weights.kernel_size, pad, pad); + acle::sme2::impl_collapsenet_channel_branch( + output_blk1, &weights.branch[1][3], output_blk2 + 3 * N, H, W, O, + weights.kernel_size, pad, 0); + } + + // --- Interleave channels --- + acle::sme2::interleave_channels(output_blk2, dst, N); + + // Delete intermediate buffers + delete[] output_blk1; + delete[] output_blk2; + delete[] padded; +} + +void impl_collapsenet_v1_1(float *src, + collapsenet_v1_1_weights_float32 &weights, + float *dst, const size_t H, const size_t W, + const size_t O) { + + // Compute resolution for all scales + uint64_t Hh = H / 2; + uint64_t Hq = H / 4; + uint64_t Wh = W / 2; + uint64_t Wq = W / 4; + + // Compute padding + uint64_t pad_2x = weights.kernel_size - 1; + uint64_t pad = pad_2x / 2; + + // Compute number of elements + uint64_t N = H * W; + uint64_t N_pad = (H + pad_2x) * (W + pad_2x); + uint64_t Nh_pad = (Hh + pad_2x) * (Wh + pad_2x); + uint64_t Nq_pad = (Hq + pad_2x) * (Wq + pad_2x); + + // Allocate buffers for intermediate data + float *buff_input_half = new float[Nh_pad](); + float *buff_interm_half = new float[Nh_pad](); + float *buff_output_half = new float[Nh_pad](); + float *buff_input_quarter = new float[Nq_pad](); + float *buff_output_quarter = new float[Nq_pad](); + float *buff_output = new float[N * 4]; + + // --- Deinterleave channels and pad --- + float *buff_padded = new float[N_pad * 4](); + acle::sme2::deinterleave_and_pad_channels(src, buff_padded, H, W, pad); + + // --- Run CollapseNet blocks --- + + // U-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(buff_padded, buff_input_half, H, W, + pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_input_quarter, &weights.branch_u[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_interm_half, &weights.branch_u[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample to full resolution + acle::sme2::impl_2x2_bilinear_upsample(buff_output_half, buff_output, Hh, + Wh, pad, 0); + } + + // G0-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(buff_padded + N_pad, + buff_input_half, H, W, pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_input_quarter, &weights.branch_g1[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_interm_half, &weights.branch_g1[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at full 
resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_padded + N_pad, buff_input_half, buff_output_half, + buff_padded + N_pad, H, W, pad, pad); + + // Run third block at full resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_padded + N_pad, &weights.branch_g1[2], buff_output + N, H, W, O, + weights.kernel_size, pad, 0); + } + + // G1-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(buff_padded + 2 * N_pad, + buff_input_half, H, W, pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_input_quarter, &weights.branch_g2[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_interm_half, &weights.branch_g2[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at full resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_padded + 2 * N_pad, buff_input_half, buff_output_half, + buff_padded + 2 * N_pad, H, W, pad, pad); + + // Run third block at full resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_padded + 2 * N_pad, &weights.branch_g2[2], buff_output + 2 * N, H, + W, O, weights.kernel_size, pad, 0); + } + + // V-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(buff_padded + 3 * N_pad, + buff_input_half, H, W, pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_input_quarter, &weights.branch_v[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_interm_half, &weights.branch_v[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample to full resolution + acle::sme2::impl_2x2_bilinear_upsample(buff_output_half, + buff_output + 3 * N, Hh, Wh, pad, 0); + } + + // --- Interleave channels --- + acle::sme2::interleave_uggv_to_rggb_channels(buff_output, dst, N); + + // Delete intermediate buffers + delete[] buff_input_half; + delete[] buff_interm_half; + delete[] buff_output_half; + delete[] buff_input_quarter; + delete[] buff_output_quarter; + delete[] buff_output; + delete[] buff_padded; +} + +void impl_collapsenet_v1_1(float16_t *src, + collapsenet_v1_1_weights_float16 &weights, + float16_t *dst, const size_t H, const size_t W, + const size_t O) { + + // Compute resolution for all scales + uint64_t Hh = H / 2; + uint64_t Hq = H / 4; + uint64_t Wh = W / 2; + uint64_t Wq = W / 4; + + // Compute padding + uint64_t pad_2x = weights.kernel_size - 1; + uint64_t pad = pad_2x / 2; + + // Compute number of elements + uint64_t N = H * W; + uint64_t N_pad = (H + pad_2x) * (W + pad_2x); + uint64_t Nh_pad = (Hh + pad_2x) * (Wh 
+ pad_2x); + uint64_t Nq_pad = (Hq + pad_2x) * (Wq + pad_2x); + + // Allocate buffers for intermediate data + float16_t *buff_input_half = new float16_t[Nh_pad](); + float16_t *buff_interm_half = new float16_t[Nh_pad](); + float16_t *buff_output_half = new float16_t[Nh_pad](); + float16_t *buff_input_quarter = new float16_t[Nq_pad](); + float16_t *buff_output_quarter = new float16_t[Nq_pad](); + float16_t *buff_output = new float16_t[N * 4]; + + // --- Deinterleave channels and pad --- + float16_t *buff_padded = new float16_t[N_pad * 4](); + acle::sme2::deinterleave_and_pad_channels(src, buff_padded, H, W, pad); + + // --- Run CollapseNet blocks --- + + // U-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(buff_padded, buff_input_half, H, W, + pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_input_quarter, &weights.branch_u[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_interm_half, &weights.branch_u[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample to full resolution + acle::sme2::impl_2x2_bilinear_upsample(buff_output_half, buff_output, Hh, + Wh, pad, 0); + } + + // G0-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(buff_padded + N_pad, + buff_input_half, H, W, pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_input_quarter, &weights.branch_g1[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_interm_half, &weights.branch_g1[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at full resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_padded + N_pad, buff_input_half, buff_output_half, + buff_padded + N_pad, H, W, pad, pad); + + // Run third block at full resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_padded + N_pad, &weights.branch_g1[2], buff_output + N, H, W, O, + weights.kernel_size, pad, 0); + } + + // G1-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(buff_padded + 2 * N_pad, + buff_input_half, H, W, pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_input_quarter, &weights.branch_g2[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, 
buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_interm_half, &weights.branch_g2[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at full resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_padded + 2 * N_pad, buff_input_half, buff_output_half, + buff_padded + 2 * N_pad, H, W, pad, pad); + + // Run third block at full resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_padded + 2 * N_pad, &weights.branch_g2[2], buff_output + 2 * N, H, + W, O, weights.kernel_size, pad, 0); + } + + // V-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(buff_padded + 3 * N_pad, + buff_input_half, H, W, pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_input_quarter, &weights.branch_v[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_interm_half, &weights.branch_v[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample to full resolution + acle::sme2::impl_2x2_bilinear_upsample(buff_output_half, + buff_output + 3 * N, Hh, Wh, pad, 0); + } + + // --- Interleave channels --- + acle::sme2::interleave_uggv_to_rggb_channels(buff_output, dst, N); + + // Delete intermediate buffers + delete[] buff_input_half; + delete[] buff_interm_half; + delete[] buff_output_half; + delete[] buff_input_quarter; + delete[] buff_output_quarter; + delete[] buff_output; + delete[] buff_padded; +} + +void impl_collapsenet_v1_1_uggv_out_wo_pad_interleave( + float *src, collapsenet_v1_1_weights_float32 &weights, float *dst, + const size_t H, const size_t W, const size_t O) { + + // Compute resolution for all scales + uint64_t Hh = H / 2; + uint64_t Hq = H / 4; + uint64_t Wh = W / 2; + uint64_t Wq = W / 4; + + // Compute padding + uint64_t pad_2x = weights.kernel_size - 1; + uint64_t pad = pad_2x / 2; + + // Compute number of elements + uint64_t N = H * W; + uint64_t N_pad = (H + pad_2x) * (W + pad_2x); + uint64_t Nh_pad = (Hh + pad_2x) * (Wh + pad_2x); + uint64_t Nq_pad = (Hq + pad_2x) * (Wq + pad_2x); + + // Allocate buffers for intermediate data + float *buff_padded = new float[N_pad](); + float *buff_input_half = new float[Nh_pad](); + float *buff_interm_half = new float[Nh_pad](); + float *buff_output_half = new float[Nh_pad](); + float *buff_input_quarter = new float[Nq_pad](); + float *buff_output_quarter = new float[Nq_pad](); + + // --- Run CollapseNet blocks --- + + // U-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src, buff_input_half, H, W, pad, + pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_input_quarter, &weights.branch_u[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half 
resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_interm_half, &weights.branch_u[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample to full resolution + acle::sme2::impl_2x2_bilinear_upsample(buff_output_half, dst, Hh, Wh, pad, + 0); + } + + // G0-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src + N_pad, buff_input_half, H, W, + pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_input_quarter, &weights.branch_g1[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_interm_half, &weights.branch_g1[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at full resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + src + N_pad, buff_input_half, buff_output_half, buff_padded, H, W, pad, + pad); + + // Run third block at full resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_padded, &weights.branch_g1[2], dst + N, H, W, O, + weights.kernel_size, pad, 0); + } + + // G1-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src + 2 * N_pad, buff_input_half, + H, W, pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_input_quarter, &weights.branch_g2[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_interm_half, &weights.branch_g2[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at full resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + src + 2 * N_pad, buff_input_half, buff_output_half, buff_padded, H, W, + pad, pad); + + // Run third block at full resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_padded, &weights.branch_g2[2], dst + 2 * N, H, W, O, + weights.kernel_size, pad, 0); + } + + // V-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src + 3 * N_pad, buff_input_half, + H, W, pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_input_quarter, &weights.branch_v[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + 
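+    // Argument order follows the declaration in collapse_net.h: the half-resolution
+    // input carries the high frequencies, the quarter-resolution input/output pair
+    // carries the low-frequency correction, and buff_interm_half receives the merged
+    // half-resolution result.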
acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_interm_half, &weights.branch_v[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample to full resolution + acle::sme2::impl_2x2_bilinear_upsample(buff_output_half, dst + 3 * N, Hh, + Wh, pad, 0); + } + + // Delete intermediate buffers + delete[] buff_padded; + delete[] buff_input_half; + delete[] buff_interm_half; + delete[] buff_output_half; + delete[] buff_input_quarter; + delete[] buff_output_quarter; +} + +void impl_collapsenet_v1_1_uggv_out_wo_pad_interleave( + float16_t *src, collapsenet_v1_1_weights_float16 &weights, float16_t *dst, + const size_t H, const size_t W, const size_t O) { + + // Compute resolution for all scales + uint64_t Hh = H / 2; + uint64_t Hq = H / 4; + uint64_t Wh = W / 2; + uint64_t Wq = W / 4; + + // Compute padding + uint64_t pad_2x = weights.kernel_size - 1; + uint64_t pad = pad_2x / 2; + + // Compute number of elements + uint64_t N = H * W; + uint64_t N_pad = (H + pad_2x) * (W + pad_2x); + uint64_t Nh_pad = (Hh + pad_2x) * (Wh + pad_2x); + uint64_t Nq_pad = (Hq + pad_2x) * (Wq + pad_2x); + + // Allocate buffers for intermediate data + float16_t *buff_padded = new float16_t[N_pad](); + float16_t *buff_input_half = new float16_t[Nh_pad](); + float16_t *buff_interm_half = new float16_t[Nh_pad](); + float16_t *buff_output_half = new float16_t[Nh_pad](); + float16_t *buff_input_quarter = new float16_t[Nq_pad](); + float16_t *buff_output_quarter = new float16_t[Nq_pad](); + + // --- Run CollapseNet blocks --- + + // U-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src, buff_input_half, H, W, pad, + pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_input_quarter, &weights.branch_u[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_interm_half, &weights.branch_u[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample to full resolution + acle::sme2::impl_2x2_bilinear_upsample(buff_output_half, dst, Hh, Wh, pad, + 0); + } + + // G0-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src + N_pad, buff_input_half, H, W, + pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_input_quarter, &weights.branch_g1[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_interm_half, &weights.branch_g1[1], buff_output_half, Hh, Wh, O, + 
weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at full resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + src + N_pad, buff_input_half, buff_output_half, buff_padded, H, W, pad, + pad); + + // Run third block at full resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_padded, &weights.branch_g1[2], dst + N, H, W, O, + weights.kernel_size, pad, 0); + } + + // G1-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src + 2 * N_pad, buff_input_half, + H, W, pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_input_quarter, &weights.branch_g2[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_interm_half, &weights.branch_g2[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at full resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + src + 2 * N_pad, buff_input_half, buff_output_half, buff_padded, H, W, + pad, pad); + + // Run third block at full resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_padded, &weights.branch_g2[2], dst + 2 * N, H, W, O, + weights.kernel_size, pad, 0); + } + + // V-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src + 3 * N_pad, buff_input_half, + H, W, pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_input_quarter, &weights.branch_v[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch( + buff_interm_half, &weights.branch_v[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample to full resolution + acle::sme2::impl_2x2_bilinear_upsample(buff_output_half, dst + 3 * N, Hh, + Wh, pad, 0); + } + + // Delete intermediate buffers + delete[] buff_padded; + delete[] buff_input_half; + delete[] buff_interm_half; + delete[] buff_output_half; + delete[] buff_input_quarter; + delete[] buff_output_quarter; +} + +void impl_collapsenet_v1_1_uggv_out_wo_pad_interleave_interp_unfused( + float16_t *src, collapsenet_v1_1_weights_interp_unfused_float16 &weights, + float16_t *dst, const size_t H, const size_t W, const size_t O) { + + // Compute resolution for all scales + uint64_t Hh = H / 2; + uint64_t Hq = H / 4; + uint64_t Wh = W / 2; + uint64_t Wq = W / 4; + + // Compute padding + uint64_t pad_2x = weights.kernel_size - 1; + uint64_t pad = pad_2x / 2; + + // Compute number of elements + uint64_t N = H * W; + uint64_t N_pad = (H + pad_2x) * (W + pad_2x); + uint64_t Nh_pad = (Hh + pad_2x) * (Wh + pad_2x); + uint64_t Nq_pad = (Hq + pad_2x) * (Wq + pad_2x); + + // Allocate 
buffers for intermediate data + float16_t *buff_padded = new float16_t[N_pad](); + float16_t *buff_input_half = new float16_t[Nh_pad](); + float16_t *buff_interm_half = new float16_t[Nh_pad](); + float16_t *buff_output_half = new float16_t[Nh_pad](); + float16_t *buff_input_quarter = new float16_t[Nq_pad](); + float16_t *buff_output_quarter = new float16_t[Nq_pad](); + + // --- Run CollapseNet blocks --- + + // U-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src, buff_input_half, H, W, pad, + pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch_interp_unfused( + buff_input_quarter, &weights.branch_u[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch_interp_unfused( + buff_interm_half, &weights.branch_u[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample to full resolution + acle::sme2::impl_2x2_bilinear_upsample(buff_output_half, dst, Hh, Wh, pad, + 0); + } + + // G0-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src + N_pad, buff_input_half, H, W, + pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch_interp_unfused( + buff_input_quarter, &weights.branch_g1[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch_interp_unfused( + buff_interm_half, &weights.branch_g1[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at full resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + src + N_pad, buff_input_half, buff_output_half, buff_padded, H, W, pad, + pad); + + // Run third block at full resolution + acle::sme2::impl_collapsenet_channel_branch_interp_unfused( + buff_padded, &weights.branch_g1[2], dst + N, H, W, O, + weights.kernel_size, pad, 0); + } + + // G1-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src + 2 * N_pad, buff_input_half, + H, W, pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch_interp_unfused( + buff_input_quarter, &weights.branch_g2[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch_interp_unfused( + buff_interm_half, &weights.branch_g2[1], 
buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at full resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + src + 2 * N_pad, buff_input_half, buff_output_half, buff_padded, H, W, + pad, pad); + + // Run third block at full resolution + acle::sme2::impl_collapsenet_channel_branch_interp_unfused( + buff_padded, &weights.branch_g2[2], dst + 2 * N, H, W, O, + weights.kernel_size, pad, 0); + } + + // V-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src + 3 * N_pad, buff_input_half, + H, W, pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch_interp_unfused( + buff_input_quarter, &weights.branch_v[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch_interp_unfused( + buff_interm_half, &weights.branch_v[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample to full resolution + acle::sme2::impl_2x2_bilinear_upsample(buff_output_half, dst + 3 * N, Hh, + Wh, pad, 0); + } + + // Delete intermediate buffers + delete[] buff_padded; + delete[] buff_input_half; + delete[] buff_interm_half; + delete[] buff_output_half; + delete[] buff_input_quarter; + delete[] buff_output_quarter; +} + +void impl_collapsenet_v1_1_uggv_out_wo_pad_interleave_interp_collect_unfused( + float16_t *src, + collapsenet_v1_1_weights_interp_collect_unfused_float16 &weights, + float16_t *dst, const size_t H, const size_t W, const size_t O) { + + // Compute resolution for all scales + uint64_t Hh = H / 2; + uint64_t Hq = H / 4; + uint64_t Wh = W / 2; + uint64_t Wq = W / 4; + + // Compute padding + uint64_t pad_2x = weights.kernel_size - 1; + uint64_t pad = pad_2x / 2; + + // Compute number of elements + uint64_t N = H * W; + uint64_t N_pad = (H + pad_2x) * (W + pad_2x); + uint64_t Nh_pad = (Hh + pad_2x) * (Wh + pad_2x); + uint64_t Nq_pad = (Hq + pad_2x) * (Wq + pad_2x); + + // Allocate buffers for intermediate data + float16_t *buff_padded = new float16_t[N_pad](); + float16_t *buff_input_half = new float16_t[Nh_pad](); + float16_t *buff_interm_half = new float16_t[Nh_pad](); + float16_t *buff_output_half = new float16_t[Nh_pad](); + float16_t *buff_input_quarter = new float16_t[Nq_pad](); + float16_t *buff_output_quarter = new float16_t[Nq_pad](); + + // --- Run CollapseNet blocks --- + + // U-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src, buff_input_half, H, W, pad, + pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch_interp_collect_unfused( + buff_input_quarter, &weights.branch_u[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + 
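+    // Same channel-branch routine as the quarter-resolution stage, now applied at
+    // Hh x Wh with the second set of U-branch weights (branch_u[1]).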
acle::sme2::impl_collapsenet_channel_branch_interp_collect_unfused( + buff_interm_half, &weights.branch_u[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample to full resolution + acle::sme2::impl_2x2_bilinear_upsample(buff_output_half, dst, Hh, Wh, pad, + 0); + } + + // G0-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src + N_pad, buff_input_half, H, W, + pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch_interp_collect_unfused( + buff_input_quarter, &weights.branch_g1[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch_interp_collect_unfused( + buff_interm_half, &weights.branch_g1[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at full resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + src + N_pad, buff_input_half, buff_output_half, buff_padded, H, W, pad, + pad); + + // Run third block at full resolution + acle::sme2::impl_collapsenet_channel_branch_interp_collect_unfused( + buff_padded, &weights.branch_g1[2], dst + N, H, W, O, + weights.kernel_size, pad, 0); + } + + // G1-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src + 2 * N_pad, buff_input_half, + H, W, pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch_interp_collect_unfused( + buff_input_quarter, &weights.branch_g2[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch_interp_collect_unfused( + buff_interm_half, &weights.branch_g2[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at full resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + src + 2 * N_pad, buff_input_half, buff_output_half, buff_padded, H, W, + pad, pad); + + // Run third block at full resolution + acle::sme2::impl_collapsenet_channel_branch_interp_collect_unfused( + buff_padded, &weights.branch_g2[2], dst + 2 * N, H, W, O, + weights.kernel_size, pad, 0); + } + + // V-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src + 3 * N_pad, buff_input_half, + H, W, pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch_interp_collect_unfused( + buff_input_quarter, &weights.branch_v[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + 
buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch_interp_collect_unfused( + buff_interm_half, &weights.branch_v[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample to full resolution + acle::sme2::impl_2x2_bilinear_upsample(buff_output_half, dst + 3 * N, Hh, + Wh, pad, 0); + } + + // Delete intermediate buffers + delete[] buff_padded; + delete[] buff_input_half; + delete[] buff_interm_half; + delete[] buff_output_half; + delete[] buff_input_quarter; + delete[] buff_output_quarter; +} + +void impl_collapsenet_v1_1_uggv_out_wo_pad_interleave_non_widening( + float16_t *src, collapsenet_v1_1_weights_float16 &weights, float16_t *dst, + const size_t H, const size_t W, const size_t O) { + + // Compute resolution for all scales + uint64_t Hh = H / 2; + uint64_t Hq = H / 4; + uint64_t Wh = W / 2; + uint64_t Wq = W / 4; + + // Compute padding + uint64_t pad_2x = weights.kernel_size - 1; + uint64_t pad = pad_2x / 2; + + // Compute number of elements + uint64_t N = H * W; + uint64_t N_pad = (H + pad_2x) * (W + pad_2x); + uint64_t Nh_pad = (Hh + pad_2x) * (Wh + pad_2x); + uint64_t Nq_pad = (Hq + pad_2x) * (Wq + pad_2x); + + // Allocate buffers for intermediate data + float16_t *buff_padded = new float16_t[N_pad](); + float16_t *buff_input_half = new float16_t[Nh_pad](); + float16_t *buff_interm_half = new float16_t[Nh_pad](); + float16_t *buff_output_half = new float16_t[Nh_pad](); + float16_t *buff_input_quarter = new float16_t[Nq_pad](); + float16_t *buff_output_quarter = new float16_t[Nq_pad](); + + // --- Run CollapseNet blocks --- + + // U-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src, buff_input_half, H, W, pad, + pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening( + buff_input_quarter, &weights.branch_u[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening( + buff_interm_half, &weights.branch_u[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample to full resolution + acle::sme2::impl_2x2_bilinear_upsample(buff_output_half, dst, Hh, Wh, pad, + 0); + } + + // G0-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src + N_pad, buff_input_half, H, W, + pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening( + buff_input_quarter, &weights.branch_g1[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening( + buff_interm_half, &weights.branch_g1[1], 
buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at full resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + src + N_pad, buff_input_half, buff_output_half, buff_padded, H, W, pad, + pad); + + // Run third block at full resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening( + buff_padded, &weights.branch_g1[2], dst + N, H, W, O, + weights.kernel_size, pad, 0); + } + + // G1-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src + 2 * N_pad, buff_input_half, + H, W, pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening( + buff_input_quarter, &weights.branch_g2[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening( + buff_interm_half, &weights.branch_g2[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at full resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + src + 2 * N_pad, buff_input_half, buff_output_half, buff_padded, H, W, + pad, pad); + + // Run third block at full resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening( + buff_padded, &weights.branch_g2[2], dst + 2 * N, H, W, O, + weights.kernel_size, pad, 0); + } + + // V-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src + 3 * N_pad, buff_input_half, + H, W, pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening( + buff_input_quarter, &weights.branch_v[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening( + buff_interm_half, &weights.branch_v[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample to full resolution + acle::sme2::impl_2x2_bilinear_upsample(buff_output_half, dst + 3 * N, Hh, + Wh, pad, 0); + } + + // Delete intermediate buffers + delete[] buff_padded; + delete[] buff_input_half; + delete[] buff_interm_half; + delete[] buff_output_half; + delete[] buff_input_quarter; + delete[] buff_output_quarter; +} + +void impl_collapsenet_v1_1_uggv_out_wo_pad_interleave_non_widening_mop4( + float16_t *src, collapsenet_v1_1_weights_float16 &weights, float16_t *dst, + const size_t H, const size_t W, const size_t O) { + + // Compute resolution for all scales + uint64_t Hh = H / 2; + uint64_t Hq = H / 4; + uint64_t Wh = W / 2; + uint64_t Wq = W / 4; + + // Compute padding + uint64_t pad_2x = weights.kernel_size - 1; + uint64_t pad = pad_2x / 2; + + // Compute number of elements + uint64_t N = H * W; + uint64_t N_pad = (H + pad_2x) * (W + pad_2x); + uint64_t Nh_pad = 
(Hh + pad_2x) * (Wh + pad_2x); + uint64_t Nq_pad = (Hq + pad_2x) * (Wq + pad_2x); + + // Allocate buffers for intermediate data + float16_t *buff_padded = new float16_t[N_pad](); + float16_t *buff_input_half = new float16_t[Nh_pad](); + float16_t *buff_interm_half = new float16_t[Nh_pad](); + float16_t *buff_output_half = new float16_t[Nh_pad](); + float16_t *buff_input_quarter = new float16_t[Nq_pad](); + float16_t *buff_output_quarter = new float16_t[Nq_pad](); + + // --- Run CollapseNet blocks --- + + // U-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src, buff_input_half, H, W, pad, + pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening_mop4( + buff_input_quarter, &weights.branch_u[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening_mop4( + buff_interm_half, &weights.branch_u[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample to full resolution + acle::sme2::impl_2x2_bilinear_upsample(buff_output_half, dst, Hh, Wh, pad, + 0); + } + + // G0-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src + N_pad, buff_input_half, H, W, + pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening_mop4( + buff_input_quarter, &weights.branch_g1[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening_mop4( + buff_interm_half, &weights.branch_g1[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at full resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + src + N_pad, buff_input_half, buff_output_half, buff_padded, H, W, pad, + pad); + + // Run third block at full resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening_mop4( + buff_padded, &weights.branch_g1[2], dst + N, H, W, O, + weights.kernel_size, pad, 0); + } + + // G1-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src + 2 * N_pad, buff_input_half, + H, W, pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening_mop4( + buff_input_quarter, &weights.branch_g2[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half 
resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening_mop4( + buff_interm_half, &weights.branch_g2[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at full resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + src + 2 * N_pad, buff_input_half, buff_output_half, buff_padded, H, W, + pad, pad); + + // Run third block at full resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening_mop4( + buff_padded, &weights.branch_g2[2], dst + 2 * N, H, W, O, + weights.kernel_size, pad, 0); + } + + // V-Channel + { + // Downsample by 2x and 4x + acle::sme2::impl_2x2_bilinear_downsample(src + 3 * N_pad, buff_input_half, + H, W, pad, pad); + acle::sme2::impl_2x2_bilinear_downsample( + buff_input_half, buff_input_quarter, Hh, Wh, pad, pad); + + // Run first block at quarter resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening_mop4( + buff_input_quarter, &weights.branch_v[0], buff_output_quarter, Hq, Wq, + O, weights.kernel_size, pad, pad); + + // Upsample and merge high and low frequencies at half resolution + acle::sme2::impl_2x2_bilinear_upsample_and_merge( + buff_input_half, buff_input_quarter, buff_output_quarter, + buff_interm_half, Hh, Wh, pad, pad); + + // Run second block at half resolution + acle::sme2::impl_collapsenet_channel_branch_non_widening_mop4( + buff_interm_half, &weights.branch_v[1], buff_output_half, Hh, Wh, O, + weights.kernel_size, pad, pad); + + // Upsample to full resolution + acle::sme2::impl_2x2_bilinear_upsample(buff_output_half, dst + 3 * N, Hh, + Wh, pad, 0); + } + + // Delete intermediate buffers + delete[] buff_padded; + delete[] buff_input_half; + delete[] buff_interm_half; + delete[] buff_output_half; + delete[] buff_input_quarter; + delete[] buff_output_quarter; +} + +} // namespace acle + +#endif \ No newline at end of file diff --git a/usecase/denoiser/core/collapse_net_sme2.cpp b/usecase/denoiser/core/collapse_net_sme2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6ab07fe64aad0cdf1875146b02f8f58c0a68809e --- /dev/null +++ b/usecase/denoiser/core/collapse_net_sme2.cpp @@ -0,0 +1,2380 @@ +#include "collapse_net.h" +#include "cc.h" + +#if ENABLE_SME2 +namespace acle { + +namespace sme2 { + +void interleave_channels(float *src, float *dst, const size_t N) ARM_STREAMING { + + float *src_c0 = src; + float *src_c1 = src + N; + float *src_c2 = src + 2 * N; + float *src_c3 = src + 3 * N; + + size_t veclen = svcntw(); + size_t veclen_4x = 4 * veclen; + for (uint64_t n = 0; n < N; n += veclen) { + svbool_t pg = svwhilelt_b32(n, N); + svst4_f32(pg, dst, + svcreate4_f32(svld1_f32(pg, src_c0), svld1_f32(pg, src_c1), + svld1_f32(pg, src_c2), svld1_f32(pg, src_c3))); + src_c0 += veclen; + src_c1 += veclen; + src_c2 += veclen; + src_c3 += veclen; + dst += veclen_4x; + } +} + +void interleave_channels(float16_t *src, float16_t *dst, + const size_t N) ARM_STREAMING { + + float16_t *src_c0 = src; + float16_t *src_c1 = src + N; + float16_t *src_c2 = src + 2 * N; + float16_t *src_c3 = src + 3 * N; + + size_t veclen = svcnth(); + size_t veclen_4x = 4 * veclen; + for (uint64_t n = 0; n < N; n += veclen) { + svbool_t pg = svwhilelt_b16(n, N); + svst4_f16(pg, dst, + svcreate4_f16(svld1_f16(pg, src_c0), svld1_f16(pg, src_c1), + svld1_f16(pg, src_c2), svld1_f16(pg, src_c3))); + src_c0 += veclen; + src_c1 += veclen; + src_c2 += veclen; + src_c3 += veclen; + dst += veclen_4x; + } +} + +void interleave_uggv_to_rggb_channels(float *src, float *dst, + 
const size_t N) ARM_STREAMING { + + float *src_c0 = src; + float *src_c1 = src + N; + float *src_c2 = src + 2 * N; + float *src_c3 = src + 3 * N; + + size_t veclen = svcntw(); + size_t veclen_4x = 4 * veclen; + for (uint64_t n = 0; n < N; n += veclen) { + svbool_t pg = svwhilelt_b32(n, N); + svfloat32x4_t uggv = + svcreate4_f32(svld1_f32(pg, src_c0), svld1_f32(pg, src_c1), + svld1_f32(pg, src_c2), svld1_f32(pg, src_c3)); + svfloat32x4_t rggb = uggv_to_rggb_f32(pg, uggv); + svst4_f32(pg, dst, rggb); + src_c0 += veclen; + src_c1 += veclen; + src_c2 += veclen; + src_c3 += veclen; + dst += veclen_4x; + } +} + +void interleave_uggv_to_rggb_channels(float16_t *src, float16_t *dst, + const size_t N) ARM_STREAMING { + + float16_t *src_c0 = src; + float16_t *src_c1 = src + N; + float16_t *src_c2 = src + 2 * N; + float16_t *src_c3 = src + 3 * N; + + size_t veclen = svcnth(); + size_t veclen_4x = 4 * veclen; + for (uint64_t n = 0; n < N; n += veclen) { + svbool_t pg = svwhilelt_b16(n, N); + svfloat16x4_t uggv = + svcreate4_f16(svld1_f16(pg, src_c0), svld1_f16(pg, src_c1), + svld1_f16(pg, src_c2), svld1_f16(pg, src_c3)); + svfloat16x4_t rggb = uggv_to_rggb_f16(pg, uggv); + svst4_f16(pg, dst, rggb); + src_c0 += veclen; + src_c1 += veclen; + src_c2 += veclen; + src_c3 += veclen; + dst += veclen_4x; + } +} + +void deinterleave_and_pad_channels(float *src, float *dst, const size_t H, + const size_t W, + const size_t pad) ARM_STREAMING { + + size_t pad_2x = 2 * pad; + + float *dst_c0 = dst; + float *dst_c1 = dst + (H + pad_2x) * (W + pad_2x); + float *dst_c2 = dst + 2 * (H + pad_2x) * (W + pad_2x); + float *dst_c3 = dst + 3 * (H + pad_2x) * (W + pad_2x); + + dst_c0 += (W + pad_2x + 1) * pad; + dst_c1 += (W + pad_2x + 1) * pad; + dst_c2 += (W + pad_2x + 1) * pad; + dst_c3 += (W + pad_2x + 1) * pad; + + size_t veclen = svcntw(); + for (uint64_t h = 0; h < H; h++) { + for (uint64_t w = 0; w < W; w += veclen) { + svbool_t pg = svwhilelt_b32(w, W); + uint64_t num_active = svcntp_b32(pg, pg); + uint64_t num_active_4x = 4 * num_active; + svfloat32x4_t src_vec = svld4_f32(pg, src); + svst1_f32(pg, dst_c0, svget4_f32(src_vec, 0)); + svst1_f32(pg, dst_c1, svget4_f32(src_vec, 1)); + svst1_f32(pg, dst_c2, svget4_f32(src_vec, 2)); + svst1_f32(pg, dst_c3, svget4_f32(src_vec, 3)); + dst_c0 += num_active; + dst_c1 += num_active; + dst_c2 += num_active; + dst_c3 += num_active; + src += num_active_4x; + } + dst_c0 += pad_2x; + dst_c1 += pad_2x; + dst_c2 += pad_2x; + dst_c3 += pad_2x; + } +} + +void deinterleave_and_pad_channels(float16_t *src, float16_t *dst, + const size_t H, const size_t W, + const size_t pad) ARM_STREAMING { + + size_t pad_2x = 2 * pad; + + float16_t *dst_c0 = dst; + float16_t *dst_c1 = dst + (H + pad_2x) * (W + pad_2x); + float16_t *dst_c2 = dst + 2 * (H + pad_2x) * (W + pad_2x); + float16_t *dst_c3 = dst + 3 * (H + pad_2x) * (W + pad_2x); + + dst_c0 += (W + pad_2x + 1) * pad; + dst_c1 += (W + pad_2x + 1) * pad; + dst_c2 += (W + pad_2x + 1) * pad; + dst_c3 += (W + pad_2x + 1) * pad; + + size_t veclen = svcnth(); + for (uint64_t h = 0; h < H; h++) { + for (uint64_t w = 0; w < W; w += veclen) { + svbool_t pg = svwhilelt_b16(w, W); + uint64_t num_active = svcntp_b16(pg, pg); + uint64_t num_active_4x = 4 * num_active; + svfloat16x4_t src_vec = svld4_f16(pg, src); + svst1_f16(pg, dst_c0, svget4_f16(src_vec, 0)); + svst1_f16(pg, dst_c1, svget4_f16(src_vec, 1)); + svst1_f16(pg, dst_c2, svget4_f16(src_vec, 2)); + svst1_f16(pg, dst_c3, svget4_f16(src_vec, 3)); + dst_c0 += num_active; + dst_c1 += num_active; + 
dst_c2 += num_active; + dst_c3 += num_active; + src += num_active_4x; + } + dst_c0 += pad_2x; + dst_c1 += pad_2x; + dst_c2 += pad_2x; + dst_c3 += pad_2x; + } +} + +void impl_2x2_bilinear_downsample(const float *src, float *dst, + const uint64_t src_height, + const uint64_t src_width, + const uint64_t src_pad, + const uint64_t dst_pad) ARM_STREAMING { + // Compute dst height and width + uint64_t dst_height = src_height / 2; + uint64_t dst_width = src_width / 2; + + // Compute stride + uint64_t src_stride = src_width + 2 * src_pad; + uint64_t dst_stride = dst_width + 2 * dst_pad; + + // Initialise pointers + float *src_ptr = const_cast(src) + src_pad * src_stride + src_pad; + float *dst_ptr = const_cast(dst) + dst_pad * dst_stride + dst_pad; + + // Main loop with SVE Intrinsic + uint64_t veclen = svcntw(); + + for (size_t h = 0; h < dst_height; h++) { + for (size_t w = 0; w < dst_width; w += veclen) { + + // Determine active elements of predicate register + svbool_t pg = svwhilelt_b32(w, dst_width); + uint64_t num_active = svcntp_b32(pg, pg); + + // Vector load + svfloat32x2_t src_vec0 = svld2_f32(pg, src_ptr); + svfloat32x2_t src_vec1 = svld2_f32(pg, src_ptr + src_stride); + + // Compute + svfloat32_t dst_vec = svdup_n_f32_x(pg, 0.0); + dst_vec = svmla_n_f32_x(pg, dst_vec, svget2_f32(src_vec0, 0), 0.25); + dst_vec = svmla_n_f32_x(pg, dst_vec, svget2_f32(src_vec0, 1), 0.25); + dst_vec = svmla_n_f32_x(pg, dst_vec, svget2_f32(src_vec1, 0), 0.25); + dst_vec = svmla_n_f32_x(pg, dst_vec, svget2_f32(src_vec1, 1), 0.25); + + // Vector store + svst1_f32(pg, dst_ptr, dst_vec); + + // Increment pointers + src_ptr += 2 * num_active; + dst_ptr += num_active; + } + + // Offset pointers + src_ptr += 2 * src_pad + src_stride; + dst_ptr += 2 * dst_pad; + } +} + +void impl_2x2_bilinear_downsample(const float16_t *src, float16_t *dst, + const uint64_t src_height, + const uint64_t src_width, + const uint64_t src_pad, + const uint64_t dst_pad) ARM_STREAMING { + // Compute dst height and width + uint64_t dst_height = src_height / 2; + uint64_t dst_width = src_width / 2; + + // Compute stride + uint64_t src_stride = src_width + 2 * src_pad; + uint64_t dst_stride = dst_width + 2 * dst_pad; + + // Initialise pointers + float16_t *src_ptr = + const_cast(src) + src_pad * src_stride + src_pad; + float16_t *dst_ptr = + const_cast(dst) + dst_pad * dst_stride + dst_pad; + + // Main loop with SVE Intrinsic + uint64_t veclen = svcnth(); + + for (size_t h = 0; h < dst_height; h++) { + for (size_t w = 0; w < dst_width; w += veclen) { + + // Determine active elements of predicate register + svbool_t pg = svwhilelt_b16(w, dst_width); + uint64_t num_active = svcntp_b16(pg, pg); + + // Vector load + svfloat16x2_t src_vec0 = svld2_f16(pg, src_ptr); + svfloat16x2_t src_vec1 = svld2_f16(pg, src_ptr + src_stride); + + // Compute + svfloat16_t dst_vec = svdup_n_f16_x(pg, 0.0); + dst_vec = svmla_n_f16_x(pg, dst_vec, svget2_f16(src_vec0, 0), 0.25); + dst_vec = svmla_n_f16_x(pg, dst_vec, svget2_f16(src_vec0, 1), 0.25); + dst_vec = svmla_n_f16_x(pg, dst_vec, svget2_f16(src_vec1, 0), 0.25); + dst_vec = svmla_n_f16_x(pg, dst_vec, svget2_f16(src_vec1, 1), 0.25); + + // Vector store + svst1_f16(pg, dst_ptr, dst_vec); + + // Increment pointers + src_ptr += 2 * num_active; + dst_ptr += num_active; + } + + // Offset pointers + src_ptr += 2 * src_pad + src_stride; + dst_ptr += 2 * dst_pad; + } +} + +inline svfloat32x2_t lerp_1d(const svbool_t pg, const svfloat32_t p0, + const svfloat32_t p1) ARM_STREAMING { + svfloat32_t o0 = svdup_n_f32_x(pg, 
0.0);
+  o0 = svmla_n_f32_x(pg, o0, p0, 0.75);
+  o0 = svmla_n_f32_x(pg, o0, p1, 0.25);
+  svfloat32_t o1 = svdup_n_f32_x(pg, 0.0);
+  o1 = svmla_n_f32_x(pg, o1, p0, 0.25);
+  o1 = svmla_n_f32_x(pg, o1, p1, 0.75);
+  return svcreate2_f32(o0, o1);
+}
+
+inline svfloat32x4_t lerp_2d(const svbool_t pg, const svfloat32_t p00,
+                             const svfloat32_t p01, const svfloat32_t p10,
+                             const svfloat32_t p11) ARM_STREAMING {
+  svfloat32_t o00 = svdup_n_f32_x(pg, 0.0);
+  o00 = svmla_n_f32_x(pg, o00, p00, 0.5625);
+  o00 = svmla_n_f32_x(pg, o00, p01, 0.1875);
+  o00 = svmla_n_f32_x(pg, o00, p10, 0.1875);
+  o00 = svmla_n_f32_x(pg, o00, p11, 0.0625);
+  svfloat32_t o01 = svdup_n_f32_x(pg, 0.0);
+  o01 = svmla_n_f32_x(pg, o01, p00, 0.1875);
+  o01 = svmla_n_f32_x(pg, o01, p01, 0.5625);
+  o01 = svmla_n_f32_x(pg, o01, p10, 0.0625);
+  o01 = svmla_n_f32_x(pg, o01, p11, 0.1875);
+  svfloat32_t o10 = svdup_n_f32_x(pg, 0.0);
+  o10 = svmla_n_f32_x(pg, o10, p00, 0.1875);
+  o10 = svmla_n_f32_x(pg, o10, p01, 0.0625);
+  o10 = svmla_n_f32_x(pg, o10, p10, 0.5625);
+  o10 = svmla_n_f32_x(pg, o10, p11, 0.1875);
+  svfloat32_t o11 = svdup_n_f32_x(pg, 0.0);
+  o11 = svmla_n_f32_x(pg, o11, p00, 0.0625);
+  o11 = svmla_n_f32_x(pg, o11, p01, 0.1875);
+  o11 = svmla_n_f32_x(pg, o11, p10, 0.1875);
+  o11 = svmla_n_f32_x(pg, o11, p11, 0.5625);
+  return svcreate4_f32(o00, o01, o10, o11);
+}
+
+void _process_edge_row_2x2_bilinear_upsample(
+    float *src_ptr, float *dst_ptr, const uint64_t src_width) ARM_STREAMING {
+
+  for (size_t w = 0; w < src_width - 1; w += svcntw()) {
+
+    // Determine active elements of predicate register
+    svbool_t pg = svwhilelt_b32(w, src_width - 1);
+    uint64_t num_active = svcntp_b32(pg, pg);
+
+    // Vector load and linear interpolate
+    svfloat32_t p0 = svld1_f32(pg, src_ptr);
+    svfloat32_t p1 = svld1_f32(pg, src_ptr + 1);
+    svfloat32x2_t src_vec = lerp_1d(pg, p0, p1);
+
+    // Interleave and store
+    svst2_f32(pg, dst_ptr, src_vec);
+
+    // Increment pointers
+    src_ptr += num_active;
+    dst_ptr += 2 * num_active;
+  }
+}
+
+void impl_2x2_bilinear_upsample(const float *src, float *dst,
+                                const uint64_t src_height,
+                                const uint64_t src_width,
+                                const uint64_t src_pad,
+                                const uint64_t dst_pad) ARM_STREAMING {
+
+  // Compute dst height and width
+  uint64_t dst_height = src_height * 2;
+  uint64_t dst_width = src_width * 2;
+
+  // Compute stride
+  uint64_t src_stride = src_width + 2 * src_pad;
+  uint64_t dst_stride = dst_width + 2 * dst_pad;
+
+  // --- Centre ---
+  {
+    float *src_ptr = const_cast<float *>(src) + src_pad * src_stride + src_pad;
+    float *dst_ptr =
+        const_cast<float *>(dst) + (dst_pad + 1) * dst_stride + dst_pad + 1;
+
+    for (size_t h = 0; h < src_height - 1; h++) {
+      for (size_t w = 0; w < src_width - 1; w += svcntw()) {
+
+        // Determine active elements of predicate register
+        svbool_t pg = svwhilelt_b32(w, src_width - 1);
+        uint64_t num_active = svcntp_b32(pg, pg);
+
+        // Vector load and linear interpolate
+        svfloat32_t p00 = svld1_f32(pg, src_ptr);
+        svfloat32_t p01 = svld1_f32(pg, src_ptr + 1);
+        svfloat32_t p10 = svld1_f32(pg, src_ptr + src_stride);
+        svfloat32_t p11 = svld1_f32(pg, src_ptr + src_stride + 1);
+        svfloat32x4_t src_vec = lerp_2d(pg, p00, p01, p10, p11);
+
+        // Interleave and store
+        svst2_f32(
+            pg, dst_ptr,
+            svcreate2_f32(svget4_f32(src_vec, 0), svget4_f32(src_vec, 1)));
+        svst2_f32(
+            pg, dst_ptr + dst_stride,
+            svcreate2_f32(svget4_f32(src_vec, 2), svget4_f32(src_vec, 3)));
+
+        // Increment pointers
+        src_ptr += num_active;
+        dst_ptr += 2 * num_active;
+      }
+
+      // Offset pointers
+      src_ptr += 2 * src_pad + 1;
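+      // Row advance: the inner loop has moved src_ptr forward by
+      // src_width - 1 elements, so adding 2 * src_pad + 1 lands it on the
+      // first interior element of the next source row. dst_ptr (below) has
+      // advanced by dst_width - 2 elements and must additionally skip one
+      // full destination row, since each source row produces two output rows.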
+ dst_ptr += 2 * dst_pad + dst_stride + 2; + } + } + + // --- Top Edge --- + { + float *src_ptr = const_cast(src) + src_pad * src_stride + src_pad; + float *dst_ptr = + const_cast(dst) + dst_pad * dst_stride + dst_pad + 1; + _process_edge_row_2x2_bilinear_upsample(src_ptr, dst_ptr, src_width); + } + // --- Bottom Edge --- + { + float *src_ptr = const_cast(src) + + (src_pad + src_height - 1) * src_stride + src_pad; + float *dst_ptr = const_cast(dst) + + (dst_pad + dst_height - 1) * dst_stride + dst_pad + 1; + _process_edge_row_2x2_bilinear_upsample(src_ptr, dst_ptr, src_width); + } + // --- Left Edge / Left Corner --- + { + float *src_ptr = const_cast(src) + src_pad * src_stride + src_pad; + float *dst_ptr = const_cast(dst) + dst_pad * dst_stride + dst_pad; + + dst_ptr[0] = src_ptr[0]; + dst_ptr += dst_stride; + for (size_t h = 0; h < src_height - 1; h++) { + dst_ptr[0] = src_ptr[0] * 0.75 + src_ptr[src_stride] * 0.25; + dst_ptr[dst_stride] = src_ptr[0] * 0.25 + src_ptr[src_stride] * 0.75; + src_ptr += src_stride; + dst_ptr += 2 * dst_stride; + } + dst_ptr[0] = src_ptr[0]; + } + + // --- Right Edge / Right Corner --- + { + float *src_ptr = const_cast(src) + src_pad * src_stride + src_pad + + src_width - 1; + float *dst_ptr = const_cast(dst) + dst_pad * dst_stride + dst_pad + + dst_width - 1; + + dst_ptr[0] = src_ptr[0]; + dst_ptr += dst_stride; + for (size_t h = 0; h < src_height - 1; h++) { + dst_ptr[0] = src_ptr[0] * 0.75 + src_ptr[src_stride] * 0.25; + dst_ptr[dst_stride] = src_ptr[0] * 0.25 + src_ptr[src_stride] * 0.75; + src_ptr += src_stride; + dst_ptr += 2 * dst_stride; + } + dst_ptr[0] = src_ptr[0]; + } +} + +void _process_edge_row_2x2_bilinear_upsample_and_merge( + float *input_ptr, float *input_lr_ptr, float *output_lr_ptr, + float *output_ptr, const uint64_t input_lr_width) ARM_STREAMING { + + for (size_t w = 0; w < input_lr_width - 1; w += svcntw()) { + + // Determine active elements of predicate register + svbool_t pg = svwhilelt_b32(w, input_lr_width - 1); + uint64_t num_active = svcntp_b32(pg, pg); + + // Vector load high resolution + svfloat32x2_t input = svld2_f32(pg, input_ptr); + + // Vector load low resolution and linear interpolate + svfloat32_t p0, p1; + + p0 = svld1_f32(pg, input_lr_ptr); + p1 = svld1_f32(pg, input_lr_ptr + 1); + svfloat32x2_t input_lr_vec = lerp_1d(pg, p0, p1); + + p0 = svld1_f32(pg, output_lr_ptr); + p1 = svld1_f32(pg, output_lr_ptr + 1); + svfloat32x2_t output_lr_vec = lerp_1d(pg, p0, p1); + + // Interleave and merge + svfloat32x2_t output = + svcreate2_f32(svadd_f32_x(pg, + svsub_f32_x(pg, svget2_f32(input, 0), + svget2_f32(input_lr_vec, 0)), + svget2_f32(output_lr_vec, 0)), + svadd_f32_x(pg, + svsub_f32_x(pg, svget2_f32(input, 1), + svget2_f32(input_lr_vec, 1)), + svget2_f32(output_lr_vec, 1))); + + // Vector store + svst2_f32(pg, output_ptr, output); + + // Increment pointers + input_ptr += 2 * num_active; + output_ptr += 2 * num_active; + input_lr_ptr += num_active; + output_lr_ptr += num_active; + } +} + +void impl_2x2_bilinear_upsample_and_merge( + const float *input, const float *input_lr, const float *output_lr, + float *output, const uint64_t input_height, const uint64_t input_width, + const uint64_t input_pad, const uint64_t output_pad) ARM_STREAMING { + + // Compute low resolution height and width + uint64_t input_lr_height = input_height / 2; + uint64_t input_lr_width = input_width / 2; + + // Compute stride + uint64_t input_stride = input_width + 2 * input_pad; + uint64_t input_lr_stride = (input_width / 2) + 2 * input_pad; + uint64_t 
output_stride = input_width + 2 * output_pad; + + // --- Centre --- + { + float *input_ptr = const_cast(input) + + (input_pad + 1) * input_stride + input_pad + 1; + float *input_lr_ptr = + const_cast(input_lr) + input_pad * input_lr_stride + input_pad; + float *output_lr_ptr = const_cast(output_lr) + + input_pad * input_lr_stride + input_pad; + float *output_ptr = const_cast(output) + + (output_pad + 1) * output_stride + output_pad + 1; + + for (size_t h = 0; h < input_lr_height - 1; h++) { + for (size_t w = 0; w < input_lr_width - 1; w += svcntw()) { + + // Determine active elements of predicate register + svbool_t pg = svwhilelt_b32(w, input_lr_width - 1); + uint64_t num_active = svcntp_b32(pg, pg); + + // Vector load high resolution + svfloat32x2_t input_row0 = svld2_f32(pg, input_ptr); + svfloat32x2_t input_row1 = svld2_f32(pg, input_ptr + input_stride); + + // Vector load low resolution and linear interpolate + svfloat32_t p00, p01, p10, p11; + + p00 = svld1_f32(pg, input_lr_ptr); + p01 = svld1_f32(pg, input_lr_ptr + 1); + p10 = svld1_f32(pg, input_lr_ptr + input_lr_stride); + p11 = svld1_f32(pg, input_lr_ptr + input_lr_stride + 1); + svfloat32x4_t input_lr_vec = lerp_2d(pg, p00, p01, p10, p11); + + p00 = svld1_f32(pg, output_lr_ptr); + p01 = svld1_f32(pg, output_lr_ptr + 1); + p10 = svld1_f32(pg, output_lr_ptr + input_lr_stride); + p11 = svld1_f32(pg, output_lr_ptr + input_lr_stride + 1); + svfloat32x4_t output_lr_vec = lerp_2d(pg, p00, p01, p10, p11); + + // Interleave and merge + svfloat32x2_t output_row0 = + svcreate2_f32(svadd_f32_x(pg, + svsub_f32_x(pg, svget2_f32(input_row0, 0), + svget4_f32(input_lr_vec, 0)), + svget4_f32(output_lr_vec, 0)), + svadd_f32_x(pg, + svsub_f32_x(pg, svget2_f32(input_row0, 1), + svget4_f32(input_lr_vec, 1)), + svget4_f32(output_lr_vec, 1))); + svfloat32x2_t output_row1 = + svcreate2_f32(svadd_f32_x(pg, + svsub_f32_x(pg, svget2_f32(input_row1, 0), + svget4_f32(input_lr_vec, 2)), + svget4_f32(output_lr_vec, 2)), + svadd_f32_x(pg, + svsub_f32_x(pg, svget2_f32(input_row1, 1), + svget4_f32(input_lr_vec, 3)), + svget4_f32(output_lr_vec, 3))); + + // Vector store + svst2_f32(pg, output_ptr, output_row0); + svst2_f32(pg, output_ptr + output_stride, output_row1); + + // Increment pointers + input_ptr += 2 * num_active; + output_ptr += 2 * num_active; + input_lr_ptr += num_active; + output_lr_ptr += num_active; + } + + // Offset pointers + input_ptr += 2 * input_pad + input_stride + 2; + output_ptr += 2 * output_pad + output_stride + 2; + input_lr_ptr += 2 * input_pad + 1; + output_lr_ptr += 2 * input_pad + 1; + } + } + + // --- Top Edge --- + { + float *input_ptr = + const_cast(input) + input_pad * input_stride + input_pad + 1; + float *input_lr_ptr = + const_cast(input_lr) + input_pad * input_lr_stride + input_pad; + float *output_lr_ptr = const_cast(output_lr) + + input_pad * input_lr_stride + input_pad; + float *output_ptr = const_cast(output) + + output_pad * output_stride + output_pad + 1; + _process_edge_row_2x2_bilinear_upsample_and_merge( + input_ptr, input_lr_ptr, output_lr_ptr, output_ptr, input_lr_width); + } + // --- Bottom Edge --- + { + float *input_ptr = const_cast(input) + + (input_pad + input_height - 1) * input_stride + + input_pad + 1; + float *input_lr_ptr = const_cast(input_lr) + + (input_pad + input_lr_height - 1) * input_lr_stride + + input_pad; + float *output_lr_ptr = const_cast(output_lr) + + (input_pad + input_lr_height - 1) * input_lr_stride + + input_pad; + float *output_ptr = const_cast(output) + + (output_pad + input_height - 1) 
* output_stride + + output_pad + 1; + _process_edge_row_2x2_bilinear_upsample_and_merge( + input_ptr, input_lr_ptr, output_lr_ptr, output_ptr, input_lr_width); + } + // --- Left Edge / Left Corner --- + { + float *input_ptr = + const_cast(input) + input_pad * input_stride + input_pad; + float *input_lr_ptr = + const_cast(input_lr) + input_pad * input_lr_stride + input_pad; + float *output_lr_ptr = const_cast(output_lr) + + input_pad * input_lr_stride + input_pad; + float *output_ptr = + const_cast(output) + output_pad * output_stride + output_pad; + + output_ptr[0] = input_ptr[0] - input_lr_ptr[0] + output_lr_ptr[0]; + input_ptr += input_stride; + output_ptr += output_stride; + for (size_t h = 0; h < input_lr_height - 1; h++) { + output_ptr[0] = + input_ptr[0] - + (input_lr_ptr[0] * 0.75 + input_lr_ptr[input_lr_stride] * 0.25) + + (output_lr_ptr[0] * 0.75 + output_lr_ptr[input_lr_stride] * 0.25); + output_ptr[output_stride] = + input_ptr[input_stride] - + (input_lr_ptr[0] * 0.25 + input_lr_ptr[input_lr_stride] * 0.75) + + (output_lr_ptr[0] * 0.25 + output_lr_ptr[input_lr_stride] * 0.75); + input_ptr += 2 * input_stride; + output_ptr += 2 * output_stride; + input_lr_ptr += input_lr_stride; + output_lr_ptr += input_lr_stride; + } + output_ptr[0] = input_ptr[0] - input_lr_ptr[0] + output_lr_ptr[0]; + } + + // --- Right Edge / Right Corner --- + { + float *input_ptr = const_cast(input) + input_pad * input_stride + + input_pad + input_width - 1; + float *input_lr_ptr = const_cast(input_lr) + + input_pad * input_lr_stride + input_pad + + input_lr_width - 1; + float *output_lr_ptr = const_cast(output_lr) + + input_pad * input_lr_stride + input_pad + + input_lr_width - 1; + float *output_ptr = const_cast(output) + + output_pad * output_stride + output_pad + input_width - + 1; + + output_ptr[0] = input_ptr[0] - input_lr_ptr[0] + output_lr_ptr[0]; + input_ptr += input_stride; + output_ptr += output_stride; + for (size_t h = 0; h < input_lr_height - 1; h++) { + output_ptr[0] = + input_ptr[0] - + (input_lr_ptr[0] * 0.75 + input_lr_ptr[input_lr_stride] * 0.25) + + (output_lr_ptr[0] * 0.75 + output_lr_ptr[input_lr_stride] * 0.25); + output_ptr[output_stride] = + input_ptr[input_stride] - + (input_lr_ptr[0] * 0.25 + input_lr_ptr[input_lr_stride] * 0.75) + + (output_lr_ptr[0] * 0.25 + output_lr_ptr[input_lr_stride] * 0.75); + input_ptr += 2 * input_stride; + output_ptr += 2 * output_stride; + input_lr_ptr += input_lr_stride; + output_lr_ptr += input_lr_stride; + } + output_ptr[0] = input_ptr[0] - input_lr_ptr[0] + output_lr_ptr[0]; + } +} + +inline svfloat16x2_t lerp_1d(const svbool_t pg, const svfloat16_t p0, + const svfloat16_t p1) ARM_STREAMING { + svfloat16_t o0 = svdup_n_f16_x(pg, 0.0); + o0 = svmla_n_f16_x(pg, o0, p0, 0.75); + o0 = svmla_n_f16_x(pg, o0, p1, 0.25); + svfloat16_t o1 = svdup_n_f16_x(pg, 0.0); + o1 = svmla_n_f16_x(pg, o1, p0, 0.25); + o1 = svmla_n_f16_x(pg, o1, p1, 0.75); + return svcreate2_f16(o0, o1); +} + +inline svfloat16x4_t lerp_2d(const svbool_t pg, const svfloat16_t p00, + const svfloat16_t p01, const svfloat16_t p10, + const svfloat16_t p11) ARM_STREAMING { + svfloat16_t o00 = svdup_n_f16_x(pg, 0.0); + o00 = svmla_n_f16_x(pg, o00, p00, 0.5625); + o00 = svmla_n_f16_x(pg, o00, p01, 0.1875); + o00 = svmla_n_f16_x(pg, o00, p10, 0.1875); + o00 = svmla_n_f16_x(pg, o00, p11, 0.0625); + svfloat16_t o01 = svdup_n_f16_x(pg, 0.0); + o01 = svmla_n_f16_x(pg, o01, p00, 0.1875); + o01 = svmla_n_f16_x(pg, o01, p01, 0.5626); + o01 = svmla_n_f16_x(pg, o01, p10, 0.0625); + o01 = svmla_n_f16_x(pg, 
o01, p11, 0.1875); + svfloat16_t o10 = svdup_n_f16_x(pg, 0.0); + o10 = svmla_n_f16_x(pg, o10, p00, 0.1875); + o10 = svmla_n_f16_x(pg, o10, p01, 0.0625); + o10 = svmla_n_f16_x(pg, o10, p10, 0.5626); + o10 = svmla_n_f16_x(pg, o10, p11, 0.1875); + svfloat16_t o11 = svdup_n_f16_x(pg, 0.0); + o11 = svmla_n_f16_x(pg, o11, p00, 0.0625); + o11 = svmla_n_f16_x(pg, o11, p01, 0.1875); + o11 = svmla_n_f16_x(pg, o11, p10, 0.1875); + o11 = svmla_n_f16_x(pg, o11, p11, 0.5625); + return svcreate4_f16(o00, o01, o10, o11); +} + +void _process_edge_row_2x2_bilinear_upsample( + float16_t *src_ptr, float16_t *dst_ptr, + const uint64_t src_width) ARM_STREAMING { + + for (size_t w = 0; w < src_width - 1; w += svcnth()) { + + // Determine active elements of predicate register + svbool_t pg = svwhilelt_b16(w, src_width - 1); + uint64_t num_active = svcntp_b16(pg, pg); + + // Vector load and linear interpolate + svfloat16_t p0 = svld1_f16(pg, src_ptr); + svfloat16_t p1 = svld1_f16(pg, src_ptr + 1); + svfloat16x2_t src_vec = lerp_1d(pg, p0, p1); + + // Interleave and store + svst2_f16(pg, dst_ptr, src_vec); + + // Increment pointers + src_ptr += num_active; + dst_ptr += 2 * num_active; + } +} + +void impl_2x2_bilinear_upsample(const float16_t *src, float16_t *dst, + const uint64_t src_height, + const uint64_t src_width, + const uint64_t src_pad, + const uint64_t dst_pad) ARM_STREAMING { + + // Compute low resolution height and width + uint64_t dst_height = src_height * 2; + uint64_t dst_width = src_width * 2; + + // Compute stride + uint64_t src_stride = src_width + 2 * src_pad; + uint64_t dst_stride = dst_width + 2 * dst_pad; + + // --- Centre --- + { + float16_t *src_ptr = + const_cast(src) + src_pad * src_stride + src_pad; + float16_t *dst_ptr = + const_cast(dst) + (dst_pad + 1) * dst_stride + dst_pad + 1; + + for (size_t h = 0; h < src_height - 1; h++) { + for (size_t w = 0; w < src_width - 1; w += svcnth()) { + + // Determine active elements of predicate register + svbool_t pg = svwhilelt_b16(w, src_width - 1); + uint64_t num_active = svcntp_b16(pg, pg); + + // Vector load and linear interpolate + svfloat16_t p00 = svld1_f16(pg, src_ptr); + svfloat16_t p01 = svld1_f16(pg, src_ptr + 1); + svfloat16_t p10 = svld1_f16(pg, src_ptr + src_stride); + svfloat16_t p11 = svld1_f16(pg, src_ptr + src_stride + 1); + svfloat16x4_t src_vec = lerp_2d(pg, p00, p01, p10, p11); + + // Interleave and store + svst2_f16( + pg, dst_ptr, + svcreate2_f16(svget4_f16(src_vec, 0), svget4_f16(src_vec, 1))); + svst2_f16( + pg, dst_ptr + dst_stride, + svcreate2_f16(svget4_f16(src_vec, 2), svget4_f16(src_vec, 3))); + + // Increment pointers + src_ptr += num_active; + dst_ptr += 2 * num_active; + } + + // Offset pointers + src_ptr += 2 * src_pad + 1; + dst_ptr += 2 * dst_pad + dst_stride + 2; + } + } + + // --- Top Edge --- + { + float16_t *src_ptr = + const_cast(src) + src_pad * src_stride + src_pad; + float16_t *dst_ptr = + const_cast(dst) + dst_pad * dst_stride + dst_pad + 1; + _process_edge_row_2x2_bilinear_upsample(src_ptr, dst_ptr, src_width); + } + // --- Bottom Edge --- + { + float16_t *src_ptr = const_cast(src) + + (src_pad + src_height - 1) * src_stride + src_pad; + float16_t *dst_ptr = const_cast(dst) + + (dst_pad + dst_height - 1) * dst_stride + dst_pad + 1; + _process_edge_row_2x2_bilinear_upsample(src_ptr, dst_ptr, src_width); + } + // --- Left Edge / Left Corner --- + { + float16_t *src_ptr = + const_cast(src) + src_pad * src_stride + src_pad; + float16_t *dst_ptr = + const_cast(dst) + dst_pad * dst_stride + dst_pad; + + 
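+    // The corner sample coincides with a source sample and is copied through;
+    // the remaining edge samples interpolate only vertically with the 1-D
+    // taps 0.75 / 0.25, because there is no horizontal neighbour outside the
+    // image. The 2-D weights in lerp_2d above are the outer products of the
+    // same taps; the checks below are an illustrative, behaviour-neutral
+    // addition documenting those values (they compile away entirely):
+    static_assert(0.75 * 0.75 == 0.5625, "nearest-corner weight");
+    static_assert(0.75 * 0.25 == 0.1875, "adjacent-edge weight");
+    static_assert(0.25 * 0.25 == 0.0625, "farthest-corner weight");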
dst_ptr[0] = src_ptr[0]; + dst_ptr += dst_stride; + for (size_t h = 0; h < src_height - 1; h++) { + dst_ptr[0] = src_ptr[0] * 0.75 + src_ptr[src_stride] * 0.25; + dst_ptr[dst_stride] = src_ptr[0] * 0.25 + src_ptr[src_stride] * 0.75; + src_ptr += src_stride; + dst_ptr += 2 * dst_stride; + } + dst_ptr[0] = src_ptr[0]; + } + + // --- Right Edge / Right Corner --- + { + float16_t *src_ptr = const_cast(src) + src_pad * src_stride + + src_pad + src_width - 1; + float16_t *dst_ptr = const_cast(dst) + dst_pad * dst_stride + + dst_pad + dst_width - 1; + + dst_ptr[0] = src_ptr[0]; + dst_ptr += dst_stride; + for (size_t h = 0; h < src_height - 1; h++) { + dst_ptr[0] = src_ptr[0] * 0.75 + src_ptr[src_stride] * 0.25; + dst_ptr[dst_stride] = src_ptr[0] * 0.25 + src_ptr[src_stride] * 0.75; + src_ptr += src_stride; + dst_ptr += 2 * dst_stride; + } + dst_ptr[0] = src_ptr[0]; + } +} + +void _process_edge_row_2x2_bilinear_upsample_and_merge( + float16_t *input_ptr, float16_t *input_lr_ptr, float16_t *output_lr_ptr, + float16_t *output_ptr, const uint64_t input_lr_width) ARM_STREAMING { + + for (size_t w = 0; w < input_lr_width - 1; w += svcnth()) { + + // Determine active elements of predicate register + svbool_t pg = svwhilelt_b16(w, input_lr_width - 1); + uint64_t num_active = svcntp_b16(pg, pg); + + // Vector load high resolution + svfloat16x2_t input = svld2_f16(pg, input_ptr); + + // Vector load low resolution and linear interpolate + svfloat16_t p0, p1; + + p0 = svld1_f16(pg, input_lr_ptr); + p1 = svld1_f16(pg, input_lr_ptr + 1); + svfloat16x2_t input_lr_vec = lerp_1d(pg, p0, p1); + + p0 = svld1_f16(pg, output_lr_ptr); + p1 = svld1_f16(pg, output_lr_ptr + 1); + svfloat16x2_t output_lr_vec = lerp_1d(pg, p0, p1); + + // Interleave and merge + svfloat16x2_t output = + svcreate2_f16(svadd_f16_x(pg, + svsub_f16_x(pg, svget2_f16(input, 0), + svget2_f16(input_lr_vec, 0)), + svget2_f16(output_lr_vec, 0)), + svadd_f16_x(pg, + svsub_f16_x(pg, svget2_f16(input, 1), + svget2_f16(input_lr_vec, 1)), + svget2_f16(output_lr_vec, 1))); + + // Vector store + svst2_f16(pg, output_ptr, output); + + // Increment pointers + input_ptr += 2 * num_active; + output_ptr += 2 * num_active; + input_lr_ptr += num_active; + output_lr_ptr += num_active; + } +} + +void impl_2x2_bilinear_upsample_and_merge( + const float16_t *input, const float16_t *input_lr, + const float16_t *output_lr, float16_t *output, const uint64_t input_height, + const uint64_t input_width, const uint64_t input_pad, + const uint64_t output_pad) ARM_STREAMING { + + // Compute low resolution height and width + uint64_t input_lr_height = input_height / 2; + uint64_t input_lr_width = input_width / 2; + + // Compute stride + uint64_t input_stride = input_width + 2 * input_pad; + uint64_t input_lr_stride = (input_width / 2) + 2 * input_pad; + uint64_t output_stride = input_width + 2 * output_pad; + + // --- Centre --- + { + float16_t *input_ptr = const_cast(input) + + (input_pad + 1) * input_stride + input_pad + 1; + float16_t *input_lr_ptr = const_cast(input_lr) + + input_pad * input_lr_stride + input_pad; + float16_t *output_lr_ptr = const_cast(output_lr) + + input_pad * input_lr_stride + input_pad; + float16_t *output_ptr = const_cast(output) + + (output_pad + 1) * output_stride + output_pad + 1; + + for (size_t h = 0; h < input_lr_height - 1; h++) { + for (size_t w = 0; w < input_lr_width - 1; w += svcnth()) { + + // Determine active elements of predicate register + svbool_t pg = svwhilelt_b16(w, input_lr_width - 1); + uint64_t num_active = svcntp_b16(pg, pg); + 
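+        // In effect a Laplacian-pyramid collapse: the full-resolution input
+        // keeps its high frequencies while its low-frequency content is
+        // replaced by the processed low-resolution result, i.e. per element
+        //   output = input - upsample(input_lr) + upsample(output_lr)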
+ // Vector load high resolution + svfloat16x2_t input_row0 = svld2_f16(pg, input_ptr); + svfloat16x2_t input_row1 = svld2_f16(pg, input_ptr + input_stride); + + // Vector load low resolution and linear interpolate + svfloat16_t p00, p01, p10, p11; + + p00 = svld1_f16(pg, input_lr_ptr); + p01 = svld1_f16(pg, input_lr_ptr + 1); + p10 = svld1_f16(pg, input_lr_ptr + input_lr_stride); + p11 = svld1_f16(pg, input_lr_ptr + input_lr_stride + 1); + svfloat16x4_t input_lr_vec = lerp_2d(pg, p00, p01, p10, p11); + + p00 = svld1_f16(pg, output_lr_ptr); + p01 = svld1_f16(pg, output_lr_ptr + 1); + p10 = svld1_f16(pg, output_lr_ptr + input_lr_stride); + p11 = svld1_f16(pg, output_lr_ptr + input_lr_stride + 1); + svfloat16x4_t output_lr_vec = lerp_2d(pg, p00, p01, p10, p11); + + // Interleave and merge + svfloat16x2_t output_row0 = + svcreate2_f16(svadd_f16_x(pg, + svsub_f16_x(pg, svget2_f16(input_row0, 0), + svget4_f16(input_lr_vec, 0)), + svget4_f16(output_lr_vec, 0)), + svadd_f16_x(pg, + svsub_f16_x(pg, svget2_f16(input_row0, 1), + svget4_f16(input_lr_vec, 1)), + svget4_f16(output_lr_vec, 1))); + svfloat16x2_t output_row1 = + svcreate2_f16(svadd_f16_x(pg, + svsub_f16_x(pg, svget2_f16(input_row1, 0), + svget4_f16(input_lr_vec, 2)), + svget4_f16(output_lr_vec, 2)), + svadd_f16_x(pg, + svsub_f16_x(pg, svget2_f16(input_row1, 1), + svget4_f16(input_lr_vec, 3)), + svget4_f16(output_lr_vec, 3))); + + // Vector store + svst2_f16(pg, output_ptr, output_row0); + svst2_f16(pg, output_ptr + output_stride, output_row1); + + // Increment pointers + input_ptr += 2 * num_active; + output_ptr += 2 * num_active; + input_lr_ptr += num_active; + output_lr_ptr += num_active; + } + + // Offset pointers + input_ptr += 2 * input_pad + input_stride + 2; + output_ptr += 2 * output_pad + output_stride + 2; + input_lr_ptr += 2 * input_pad + 1; + output_lr_ptr += 2 * input_pad + 1; + } + } + + // --- Top Edge --- + { + float16_t *input_ptr = const_cast(input) + + input_pad * input_stride + input_pad + 1; + float16_t *input_lr_ptr = const_cast(input_lr) + + input_pad * input_lr_stride + input_pad; + float16_t *output_lr_ptr = const_cast(output_lr) + + input_pad * input_lr_stride + input_pad; + float16_t *output_ptr = const_cast(output) + + output_pad * output_stride + output_pad + 1; + _process_edge_row_2x2_bilinear_upsample_and_merge( + input_ptr, input_lr_ptr, output_lr_ptr, output_ptr, input_lr_width); + } + // --- Bottom Edge --- + { + float16_t *input_ptr = const_cast(input) + + (input_pad + input_height - 1) * input_stride + + input_pad + 1; + float16_t *input_lr_ptr = + const_cast(input_lr) + + (input_pad + input_lr_height - 1) * input_lr_stride + input_pad; + float16_t *output_lr_ptr = + const_cast(output_lr) + + (input_pad + input_lr_height - 1) * input_lr_stride + input_pad; + float16_t *output_ptr = const_cast(output) + + (output_pad + input_height - 1) * output_stride + + output_pad + 1; + _process_edge_row_2x2_bilinear_upsample_and_merge( + input_ptr, input_lr_ptr, output_lr_ptr, output_ptr, input_lr_width); + } + // --- Left Edge / Left Corner --- + { + float16_t *input_ptr = + const_cast(input) + input_pad * input_stride + input_pad; + float16_t *input_lr_ptr = const_cast(input_lr) + + input_pad * input_lr_stride + input_pad; + float16_t *output_lr_ptr = const_cast(output_lr) + + input_pad * input_lr_stride + input_pad; + float16_t *output_ptr = const_cast(output) + + output_pad * output_stride + output_pad; + + output_ptr[0] = input_ptr[0] - input_lr_ptr[0] + output_lr_ptr[0]; + input_ptr += input_stride; + 
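+    // At the corners the low- and full-resolution samples are co-located, so
+    // the merge needs no interpolation and reduces to the plain per-sample
+    // substitution performed above.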
output_ptr += output_stride; + for (size_t h = 0; h < input_lr_height - 1; h++) { + output_ptr[0] = + input_ptr[0] - + (input_lr_ptr[0] * 0.75 + input_lr_ptr[input_lr_stride] * 0.25) + + (output_lr_ptr[0] * 0.75 + output_lr_ptr[input_lr_stride] * 0.25); + output_ptr[output_stride] = + input_ptr[input_stride] - + (input_lr_ptr[0] * 0.25 + input_lr_ptr[input_lr_stride] * 0.75) + + (output_lr_ptr[0] * 0.25 + output_lr_ptr[input_lr_stride] * 0.75); + input_ptr += 2 * input_stride; + output_ptr += 2 * output_stride; + input_lr_ptr += input_lr_stride; + output_lr_ptr += input_lr_stride; + } + output_ptr[0] = input_ptr[0] - input_lr_ptr[0] + output_lr_ptr[0]; + } + + // --- Right Edge / Right Corner --- + { + float16_t *input_ptr = const_cast(input) + + input_pad * input_stride + input_pad + input_width - + 1; + float16_t *input_lr_ptr = const_cast(input_lr) + + input_pad * input_lr_stride + input_pad + + input_lr_width - 1; + float16_t *output_lr_ptr = const_cast(output_lr) + + input_pad * input_lr_stride + input_pad + + input_lr_width - 1; + float16_t *output_ptr = const_cast(output) + + output_pad * output_stride + output_pad + + input_width - 1; + + output_ptr[0] = input_ptr[0] - input_lr_ptr[0] + output_lr_ptr[0]; + input_ptr += input_stride; + output_ptr += output_stride; + for (size_t h = 0; h < input_lr_height - 1; h++) { + output_ptr[0] = + input_ptr[0] - + (input_lr_ptr[0] * 0.75 + input_lr_ptr[input_lr_stride] * 0.25) + + (output_lr_ptr[0] * 0.75 + output_lr_ptr[input_lr_stride] * 0.25); + output_ptr[output_stride] = + input_ptr[input_stride] - + (input_lr_ptr[0] * 0.25 + input_lr_ptr[input_lr_stride] * 0.75) + + (output_lr_ptr[0] * 0.25 + output_lr_ptr[input_lr_stride] * 0.75); + input_ptr += 2 * input_stride; + output_ptr += 2 * output_stride; + input_lr_ptr += input_lr_stride; + output_lr_ptr += input_lr_stride; + } + output_ptr[0] = input_ptr[0] - input_lr_ptr[0] + output_lr_ptr[0]; + } +} + +void store_za32(const svbool_t pg_true, float *buff, + const size_t C) ARM_STREAMING ARM_INOUT_ZA { + + for (size_t i = 0; i < C; i++) { + svst1_hor_za32(0, i, pg_true, buff); + svst1_hor_za32(1, i, pg_true, buff + svcntw()); + svst1_hor_za32(2, i, pg_true, buff + 2 * svcntw()); + svst1_hor_za32(3, i, pg_true, buff + 3 * svcntw()); + buff += svcntb(); + } +} + +void load_bias_za32(const svbool_t pg_true, float *bias, + const size_t C) ARM_STREAMING ARM_INOUT_ZA { + + for (size_t i = 0; i < C; i += 4) { + svfloat32x4_t b = svcreate4_f32(svdup_n_f32_x(pg_true, bias[i]), + svdup_n_f32_x(pg_true, bias[i + 1]), + svdup_n_f32_x(pg_true, bias[i + 2]), + svdup_n_f32_x(pg_true, bias[i + 3])); + svwrite_hor_za32_f32_vg4(0, i, b); + svwrite_hor_za32_f32_vg4(1, i, b); + svwrite_hor_za32_f32_vg4(2, i, b); + svwrite_hor_za32_f32_vg4(3, i, b); + } +} + +svfloat32x4_t convolve_and_blur_za32( + const svcount_t pc, const svbool_t pg_true, const svbool_t pg_w, float *src, + float *weight, float *weight_blur, float *bias_blur, const size_t C, + const size_t K, const size_t stride) ARM_STREAMING ARM_INOUT_ZA { + + svfloat32_t res0 = svdup_n_f32_x(pg_true, *bias_blur); + svfloat32_t res1 = svdup_n_f32_x(pg_true, *bias_blur); + svfloat32_t res2 = svdup_n_f32_x(pg_true, *bias_blur); + svfloat32_t res3 = svdup_n_f32_x(pg_true, *bias_blur); + + for (size_t ki = 0; ki < K; ki++) { + for (size_t kj = 0; kj < K; kj++) { + + float wblur = *weight_blur; + + svfloat32_t Avec = svld1_f32(pg_w, weight); + svfloat32x4_t Bvec = svld1_f32_x4(pc, src); + + svmopa_za32_f32_m(0, pg_w, pg_true, Avec, svget4_f32(Bvec, 0)); + 
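+      // Each FMOPA accumulates the outer product of the per-channel weight
+      // vector (Avec) with one vector of input pixels into its own ZA tile,
+      // so after the K x K loop tile t holds bias-plus-convolution partial
+      // sums with output channels as rows and the t-th group of pixel
+      // positions as columns. The scalar blur weight is applied to the same
+      // pixel vectors below to build the blurred output without touching ZA.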
svmopa_za32_f32_m(1, pg_w, pg_true, Avec, svget4_f32(Bvec, 1)); + svmopa_za32_f32_m(2, pg_w, pg_true, Avec, svget4_f32(Bvec, 2)); + svmopa_za32_f32_m(3, pg_w, pg_true, Avec, svget4_f32(Bvec, 3)); + + res0 = svmla_n_f32_x(pg_true, res0, svget4_f32(Bvec, 0), wblur); + res1 = svmla_n_f32_x(pg_true, res1, svget4_f32(Bvec, 1), wblur); + res2 = svmla_n_f32_x(pg_true, res2, svget4_f32(Bvec, 2), wblur); + res3 = svmla_n_f32_x(pg_true, res3, svget4_f32(Bvec, 3), wblur); + + weight += C; + weight_blur++; + src++; + } + src += stride; + } + + return svcreate4_f32(res0, res1, res2, res3); +} + +void relu_pointwise_convolve_lower_za32( + const svcount_t pc, const svbool_t pg_true, const svbool_t pg_w, float *src, + float *weight, const size_t C) ARM_STREAMING ARM_INOUT_ZA { + + for (size_t c = 0; c < C; c++) { + + svfloat32_t Avec = svld1_f32(pg_w, weight); + svfloat32x4_t Bvec = svld1_f32_x4(pc, src); + + svmopa_za32_f32_m(0, pg_w, pg_true, Avec, + svmax_n_f32_x(pg_true, svget4_f32(Bvec, 0), (float)0.0)); + svmopa_za32_f32_m(1, pg_w, pg_true, Avec, + svmax_n_f32_x(pg_true, svget4_f32(Bvec, 1), (float)0.0)); + svmopa_za32_f32_m(2, pg_w, pg_true, Avec, + svmax_n_f32_x(pg_true, svget4_f32(Bvec, 2), (float)0.0)); + svmopa_za32_f32_m(3, pg_w, pg_true, Avec, + svmax_n_f32_x(pg_true, svget4_f32(Bvec, 3), (float)0.0)); + + weight += C; + src += svcntb(); + } +} + +void relu_multiply_and_reduce_za32(const svbool_t pg_true, + svfloat32x4_t &accumulator, + const size_t C) ARM_STREAMING ARM_INOUT_ZA { + + svfloat32_t res0 = svget4_f32(accumulator, 0); + svfloat32_t res1 = svget4_f32(accumulator, 1); + svfloat32_t res2 = svget4_f32(accumulator, 2); + svfloat32_t res3 = svget4_f32(accumulator, 3); + + for (size_t i0 = 0; i0 < C; i0 += 2) { + + size_t i1 = i0 + C; + + svfloat32x2_t slice0, slice1; + + slice0 = svread_hor_za32_f32_vg2(0, i0); + slice1 = svread_hor_za32_f32_vg2(0, i1); + res0 = + svmla_f32_x(pg_true, res0, + svmax_n_f32_x(pg_true, svget2_f32(slice0, 0), (float)0.0), + svget2_f32(slice1, 0)); + res0 = + svmla_f32_x(pg_true, res0, + svmax_n_f32_x(pg_true, svget2_f32(slice0, 1), (float)0.0), + svget2_f32(slice1, 1)); + + slice0 = svread_hor_za32_f32_vg2(1, i0); + slice1 = svread_hor_za32_f32_vg2(1, i1); + res1 = + svmla_f32_x(pg_true, res1, + svmax_n_f32_x(pg_true, svget2_f32(slice0, 0), (float)0.0), + svget2_f32(slice1, 0)); + res1 = + svmla_f32_x(pg_true, res1, + svmax_n_f32_x(pg_true, svget2_f32(slice0, 1), (float)0.0), + svget2_f32(slice1, 1)); + + slice0 = svread_hor_za32_f32_vg2(2, i0); + slice1 = svread_hor_za32_f32_vg2(2, i1); + res2 = + svmla_f32_x(pg_true, res2, + svmax_n_f32_x(pg_true, svget2_f32(slice0, 0), (float)0.0), + svget2_f32(slice1, 0)); + res2 = + svmla_f32_x(pg_true, res2, + svmax_n_f32_x(pg_true, svget2_f32(slice0, 1), (float)0.0), + svget2_f32(slice1, 1)); + + slice0 = svread_hor_za32_f32_vg2(3, i0); + slice1 = svread_hor_za32_f32_vg2(3, i1); + res3 = + svmla_f32_x(pg_true, res3, + svmax_n_f32_x(pg_true, svget2_f32(slice0, 0), (float)0.0), + svget2_f32(slice1, 0)); + res3 = + svmla_f32_x(pg_true, res3, + svmax_n_f32_x(pg_true, svget2_f32(slice0, 1), (float)0.0), + svget2_f32(slice1, 1)); + } + + accumulator = svcreate4_f32(res0, res1, res2, res3); +} + +ARM_NEW_ZA void impl_collapsenet_channel_branch( + float *src, branch_weights_float32 *weights, float *dst, const size_t H, + const size_t W, const size_t O, const size_t K, const size_t src_pad, + const size_t dst_pad) ARM_STREAMING { + + // Determine vector length + size_t svl = svcntw(); + size_t svl4x = svcntb(); + + // THIS 
IMPLEMENTATION ONLY WORKS IF OUTPUT CHANNELS ARE LESS THAN THE VECTOR + // LENGTH FOR 32-BIT DATA TYPE + size_t Oh = O / 2; + assert(O <= svl); + + // Create buffer to hold ZA data + float buff[svl * svl * 2]; + + // Define useful predicates + svbool_t pg_true = svptrue_b32(); + svbool_t pg_O = svwhilelt_b32((uint64_t)0, O); + svbool_t pg_Oh = svwhilelt_b32((uint64_t)0, Oh); + + // Offset to account for padding + size_t src_pad_offset = 2 * src_pad; + size_t dst_pad_offset = 2 * dst_pad; + dst += (W + dst_pad_offset + 1) * dst_pad; + + // Determine stride applied to src pointer after kernel edge is reached + size_t stride = W + src_pad_offset - K; + + // Main loop with SME Intrinsic + for (size_t h = 0; h < H; h++) { + for (size_t w = 0; w < W; w += svl4x) { + + // Determine active elements of predicate-as-counter register + svcount_t pc = svwhilelt_c32((uint64_t)w, (uint64_t)W, (uint64_t)4); + uint64_t num_active = std::min(W - w, svl4x); + + // Load bias in ZA tiles + load_bias_za32(pg_true, weights->bias_conv, O); + + // Perform fused convolution and blur + svfloat32x4_t output_tuple = convolve_and_blur_za32( + pc, pg_true, pg_O, src, weights->weight_conv, weights->weight_blur, + weights->bias_blur, O, K, stride); + + // --- Detection Branch --- + + // Buffer Oh rows of ZA tiles + store_za32(pg_true, buff, Oh); + + // Load bias in Oh rows of ZA tiles + load_bias_za32(pg_true, weights->bias_pointwise_conv0, Oh); + + // Pointwise convolution + relu_pointwise_convolve_lower_za32(pc, pg_true, pg_Oh, buff, + weights->weight_pointwise_conv0, Oh); + + // Buffer Oh rows of ZA tiles + store_za32(pg_true, buff, Oh); + + // Load bias in Oh rows of ZA tiles + load_bias_za32(pg_true, weights->bias_pointwise_conv1, Oh); + + // Pointwise convolution + relu_pointwise_convolve_lower_za32(pc, pg_true, pg_Oh, buff, + weights->weight_pointwise_conv1, Oh); + + // --- Merge with Interpolation Branch --- + + // Multiply interpolation and detection branches and accumulate into + // blurred output + relu_multiply_and_reduce_za32(pg_true, output_tuple, Oh); + + // Store output + svst1_f32_x4(pc, dst, output_tuple); + + // Increment pointers + src += num_active; + dst += num_active; + } + + // Offset pointers to account for padding + src += src_pad_offset; + dst += dst_pad_offset; + } +} + +void load_bias_za32(const svbool_t pg_true, float16_t *bias, + const size_t C) ARM_STREAMING ARM_INOUT_ZA { + + for (size_t i = 0; i < C; i += 4) { + svfloat32x4_t b = svcreate4_f32(svdup_n_f32_x(pg_true, (float)bias[i]), + svdup_n_f32_x(pg_true, (float)bias[i + 1]), + svdup_n_f32_x(pg_true, (float)bias[i + 2]), + svdup_n_f32_x(pg_true, (float)bias[i + 3])); + svwrite_hor_za32_f32_vg4(0, i, b); + svwrite_hor_za32_f32_vg4(1, i, b); + svwrite_hor_za32_f32_vg4(2, i, b); + svwrite_hor_za32_f32_vg4(3, i, b); + } +} + +svfloat16x2_t +convolve_and_blur_za32(const svcount_t pc, const svbool_t pg_true, + const svbool_t pg_w, float16_t *src, float16_t *weight, + float16_t *weight_blur, float16_t *bias_blur, + const size_t C, const size_t K, + const size_t stride) ARM_STREAMING ARM_INOUT_ZA { + + svfloat16_t res0 = svdup_n_f16_x(pg_true, *bias_blur); + svfloat16_t res1 = svdup_n_f16_x(pg_true, *bias_blur); + + for (size_t ki = 0; ki < K; ki++) { + for (size_t kj = 0; kj < K; kj += 2) { + + svfloat16_t Avec = svld1_f16(pg_w, weight); + svfloat16x2_t Bvec0 = svld1_f16_x2(pc, src); + svfloat16x2_t Bvec1 = svld1_f16_x2(pc, src + 1); + + svmopa_za32_f16_m(0, pg_w, pg_true, Avec, + svzip1_f16(svget2_f16(Bvec0, 0), svget2_f16(Bvec1, 0))); + 
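+      // Widening FMOPA: Bvec0 and Bvec1 hold the pixels under kernel columns
+      // kj and kj + 1, and svzip1/svzip2 interleave them so each fp32 ZA
+      // accumulator receives the sum of the two fp16 products for a pixel,
+      // folding two kernel columns into one outer-product instruction (the
+      // weight array evidently stores the two column weights of each output
+      // channel adjacently to match this pairing).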
svmopa_za32_f16_m(1, pg_w, pg_true, Avec, + svzip2_f16(svget2_f16(Bvec0, 0), svget2_f16(Bvec1, 0))); + svmopa_za32_f16_m(2, pg_w, pg_true, Avec, + svzip1_f16(svget2_f16(Bvec0, 1), svget2_f16(Bvec1, 1))); + svmopa_za32_f16_m(3, pg_w, pg_true, Avec, + svzip2_f16(svget2_f16(Bvec0, 1), svget2_f16(Bvec1, 1))); + + float16_t wblur0 = weight_blur[0]; + float16_t wblur1 = weight_blur[1]; + + res0 = svmla_n_f16_x(pg_true, res0, svget2_f16(Bvec0, 0), wblur0); + res0 = svmla_n_f16_x(pg_true, res0, svget2_f16(Bvec1, 0), wblur1); + res1 = svmla_n_f16_x(pg_true, res1, svget2_f16(Bvec0, 1), wblur0); + res1 = svmla_n_f16_x(pg_true, res1, svget2_f16(Bvec1, 1), wblur1); + + weight += 2 * C; + weight_blur += 2; + src += 2; + } + src += stride; + } + + return svcreate2_f16(res0, res1); +} + +void pointwise_convolve_za32(const svcount_t pc, const svbool_t pg_true, + const svbool_t pg_w, float *src, float16_t *weight, + const size_t C) ARM_STREAMING ARM_INOUT_ZA { + + for (size_t c = 0; c < C; c += 2) { + + svfloat16_t Avec = svld1_f16(pg_w, weight); + svfloat32x4_t Bvec0 = svld1_f32_x4(pc, src); + svfloat32x4_t Bvec1 = svld1_vnum_f32_x4(pc, src, 4); + svfloat16_t Bvec; + + Bvec = svtrn1_f16(svcvt_f16_f32_x(pg_true, svget4_f32(Bvec0, 0)), + svcvt_f16_f32_x(pg_true, svget4_f32(Bvec1, 0))); + svmopa_za32_f16_m(0, pg_w, pg_true, Avec, Bvec); + + Bvec = svtrn1_f16(svcvt_f16_f32_x(pg_true, svget4_f32(Bvec0, 1)), + svcvt_f16_f32_x(pg_true, svget4_f32(Bvec1, 1))); + svmopa_za32_f16_m(1, pg_w, pg_true, Avec, Bvec); + + Bvec = svtrn1_f16(svcvt_f16_f32_x(pg_true, svget4_f32(Bvec0, 2)), + svcvt_f16_f32_x(pg_true, svget4_f32(Bvec1, 2))); + svmopa_za32_f16_m(2, pg_w, pg_true, Avec, Bvec); + + Bvec = svtrn1_f16(svcvt_f16_f32_x(pg_true, svget4_f32(Bvec0, 3)), + svcvt_f16_f32_x(pg_true, svget4_f32(Bvec1, 3))); + svmopa_za32_f16_m(3, pg_w, pg_true, Avec, Bvec); + + weight += 2 * C; + src += 8 * svcntw(); + } +} + +void relu_pointwise_convolve_za32(const svcount_t pc, const svbool_t pg_true, + const svbool_t pg_w, float *src, + float16_t *weight, + const size_t C) ARM_STREAMING ARM_INOUT_ZA { + + for (size_t c = 0; c < C; c += 2) { + + svfloat16_t Avec = svld1_f16(pg_w, weight); + svfloat32x4_t Bvec0 = svld1_f32_x4(pc, src); + svfloat32x4_t Bvec1 = svld1_vnum_f32_x4(pc, src, 4); + svfloat16_t Bvec; + + Bvec = svtrn1_f16(svcvt_f16_f32_x(pg_true, svget4_f32(Bvec0, 0)), + svcvt_f16_f32_x(pg_true, svget4_f32(Bvec1, 0))); + Bvec = svmax_n_f16_x(pg_true, Bvec, (float16_t)0.0); + svmopa_za32_f16_m(0, pg_w, pg_true, Avec, Bvec); + + Bvec = svtrn1_f16(svcvt_f16_f32_x(pg_true, svget4_f32(Bvec0, 1)), + svcvt_f16_f32_x(pg_true, svget4_f32(Bvec1, 1))); + Bvec = svmax_n_f16_x(pg_true, Bvec, (float16_t)0.0); + svmopa_za32_f16_m(1, pg_w, pg_true, Avec, Bvec); + + Bvec = svtrn1_f16(svcvt_f16_f32_x(pg_true, svget4_f32(Bvec0, 2)), + svcvt_f16_f32_x(pg_true, svget4_f32(Bvec1, 2))); + Bvec = svmax_n_f16_x(pg_true, Bvec, (float16_t)0.0); + svmopa_za32_f16_m(2, pg_w, pg_true, Avec, Bvec); + + Bvec = svtrn1_f16(svcvt_f16_f32_x(pg_true, svget4_f32(Bvec0, 3)), + svcvt_f16_f32_x(pg_true, svget4_f32(Bvec1, 3))); + Bvec = svmax_n_f16_x(pg_true, Bvec, (float16_t)0.0); + svmopa_za32_f16_m(3, pg_w, pg_true, Avec, Bvec); + + weight += 2 * C; + src += 8 * svcntw(); + } +} + +ARM_NEW_ZA void impl_collapsenet_channel_branch( + float16_t *src, branch_weights_float16 *weights, float16_t *dst, + const size_t H, const size_t W, const size_t O, const size_t K, + const size_t src_pad, const size_t dst_pad) ARM_STREAMING { + + // Determine vector length + size_t svl = 
svcntw(); + size_t svl2x = svcnth(); + size_t svl4x = svcntb(); + + // THIS IMPLEMENTATION ONLY WORKS IF OUTPUT CHANNELS ARE LESS THAN THE VECTOR + // LENGTH FOR 32-BIT DATA TYPE + size_t Oh = O / 2; + assert(O <= svl); + + // Create buffer to hold ZA data + float buff[svl * svl * 2]; + + // Define useful predicates + svbool_t pg_true = svptrue_b16(); + svbool_t pg_O = svwhilelt_b16((uint64_t)0, 2 * O); + svbool_t pg_Oh = svwhilelt_b16((uint64_t)0, O); + + // Offset to account for padding + size_t src_pad_offset = 2 * src_pad; + size_t dst_pad_offset = 2 * dst_pad; + dst += (W + dst_pad_offset + 1) * dst_pad; + + // Determine stride applied to src pointer after kernel edge is reached + size_t stride = W + src_pad_offset - (K + 1); + + // Main loop with SME Intrinsic + for (size_t h = 0; h < H; h++) { + for (size_t w = 0; w < W; w += svl4x) { + + // Determine active elements of predicate-as-counter register + svcount_t pc_f16 = svwhilelt_c16((uint64_t)w, (uint64_t)W, (uint64_t)2); + svcount_t pc_f32 = svwhilelt_c32((uint64_t)w, (uint64_t)W, (uint64_t)4); + uint64_t num_active = std::min(W - w, svl4x); + + // Load bias in ZA tiles + load_bias_za32(pg_true, weights->bias_conv, O); + + // Perform fused convolution and blur + svfloat16x2_t output_blur = convolve_and_blur_za32( + pc_f16, pg_true, pg_O, src, weights->weight_conv, + weights->weight_blur, weights->bias_blur, O, K, stride); + + // --- Detection Branch --- + + // Buffer Oh rows of ZA tiles + store_za32(pg_true, buff, Oh); + + // Load bias in Oh rows of ZA tiles + load_bias_za32(pg_true, weights->bias_pointwise_conv0, Oh); + + // Pointwise convolution + relu_pointwise_convolve_za32(pc_f32, pg_true, pg_Oh, buff, + weights->weight_pointwise_conv0, Oh); + + // Buffer Oh rows of ZA tiles + store_za32(pg_true, buff, Oh); + + // Load bias in Oh rows of ZA tiles + load_bias_za32(pg_true, weights->bias_pointwise_conv1, Oh); + + // Pointwise convolution + relu_pointwise_convolve_za32(pc_f32, pg_true, pg_Oh, buff, + weights->weight_pointwise_conv1, Oh); + + // --- Merge with Interpolation Branch --- + + // Multiply interpolation and detection branches and accumulate + svfloat32x4_t output_tuple = svcreate4_f32( + svcvt_f32_f16_x(pg_true, svzip1_f16(svget2_f16(output_blur, 0), + svget2_f16(output_blur, 0))), + svcvt_f32_f16_x(pg_true, svzip2_f16(svget2_f16(output_blur, 0), + svget2_f16(output_blur, 0))), + svcvt_f32_f16_x(pg_true, svzip1_f16(svget2_f16(output_blur, 1), + svget2_f16(output_blur, 1))), + svcvt_f32_f16_x(pg_true, svzip2_f16(svget2_f16(output_blur, 1), + svget2_f16(output_blur, 1)))); + relu_multiply_and_reduce_za32(pg_true, output_tuple, Oh); + + // Cast to FP16 and store output + svfloat16x2_t cast = svcreate2_f16( + svuzp1_f16(svcvt_f16_f32_x(pg_true, svget4_f32(output_tuple, 0)), + svcvt_f16_f32_x(pg_true, svget4_f32(output_tuple, 1))), + svuzp1_f16(svcvt_f16_f32_x(pg_true, svget4_f32(output_tuple, 2)), + svcvt_f16_f32_x(pg_true, svget4_f32(output_tuple, 3)))); + svst1_f16_x2(pc_f16, dst, cast); + + // Increment pointers + src += num_active; + dst += num_active; + } + + // Offset pointers to account for padding + src += src_pad_offset; + dst += dst_pad_offset; + } +} + +void relu_multiply_with_vec_and_reduce_za32( + const svcount_t pc, const svbool_t pg_true, float *src, + svfloat32x4_t &accumulator, const size_t C) ARM_STREAMING ARM_INOUT_ZA { + + svfloat32_t res0 = svget4_f32(accumulator, 0); + svfloat32_t res1 = svget4_f32(accumulator, 1); + svfloat32_t res2 = svget4_f32(accumulator, 2); + svfloat32_t res3 = 
svget4_f32(accumulator, 3); + + for (size_t i = 0; i < C; i += 2) { + + svfloat32x4_t src_vec0 = svld1_f32_x4(pc, src); + svfloat32x4_t src_vec1 = svld1_vnum_f32_x4(pc, src, 4); + + svfloat32x2_t slice; + + slice = svread_hor_za32_f32_vg2(0, i); + res0 = + svmla_f32_x(pg_true, res0, + svmax_n_f32_x(pg_true, svget4_f32(src_vec0, 0), (float)0.0), + svget2_f32(slice, 0)); + res0 = + svmla_f32_x(pg_true, res0, + svmax_n_f32_x(pg_true, svget4_f32(src_vec1, 0), (float)0.0), + svget2_f32(slice, 1)); + + slice = svread_hor_za32_f32_vg2(1, i); + res1 = + svmla_f32_x(pg_true, res1, + svmax_n_f32_x(pg_true, svget4_f32(src_vec0, 1), (float)0.0), + svget2_f32(slice, 0)); + res1 = + svmla_f32_x(pg_true, res1, + svmax_n_f32_x(pg_true, svget4_f32(src_vec1, 1), (float)0.0), + svget2_f32(slice, 1)); + + slice = svread_hor_za32_f32_vg2(2, i); + res2 = + svmla_f32_x(pg_true, res2, + svmax_n_f32_x(pg_true, svget4_f32(src_vec0, 2), (float)0.0), + svget2_f32(slice, 0)); + res2 = + svmla_f32_x(pg_true, res2, + svmax_n_f32_x(pg_true, svget4_f32(src_vec1, 2), (float)0.0), + svget2_f32(slice, 1)); + + slice = svread_hor_za32_f32_vg2(3, i); + res3 = + svmla_f32_x(pg_true, res3, + svmax_n_f32_x(pg_true, svget4_f32(src_vec0, 3), (float)0.0), + svget2_f32(slice, 0)); + res3 = + svmla_f32_x(pg_true, res3, + svmax_n_f32_x(pg_true, svget4_f32(src_vec1, 3), (float)0.0), + svget2_f32(slice, 1)); + + src += 8 * svcntw(); + } + + accumulator = svcreate4_f32(res0, res1, res2, res3); +} + +ARM_NEW_ZA void impl_collapsenet_channel_branch_interp_unfused( + float16_t *src, branch_weights_interp_unfused_float16 *weights, + float16_t *dst, const size_t H, const size_t W, const size_t O, + const size_t K, const size_t src_pad, const size_t dst_pad) ARM_STREAMING { + + // Determine vector length + size_t svl = svcntw(); + size_t svl2x = svcnth(); + size_t svl4x = svcntb(); + + // THIS IMPLEMENTATION ONLY WORKS IF OUTPUT CHANNELS ARE LESS THAN + // VECTOR LENGTH FOR 32-BIT DATA TYPE + assert(O <= svl); + + // Create buffers to hold ZA data + float buff1[svl * svl * 4]; + float buff2[svl * svl * 4]; + + // Define useful predicates + svbool_t pg_true = svptrue_b16(); + svbool_t pg_O = svwhilelt_b16((uint64_t)0, 2 * O); + + // Offset to account for padding + size_t src_pad_offset = 2 * src_pad; + size_t dst_pad_offset = 2 * dst_pad; + dst += (W + dst_pad_offset + 1) * dst_pad; + + // Determine stride applied to src pointer after kernel edge is reached + size_t stride = W + src_pad_offset - (K + 1); + + // Main loop with SME Intrinsic + for (size_t h = 0; h < H; h++) { + for (size_t w = 0; w < W; w += svl4x) { + + // Determine active elements of predicate-as-counter register + svcount_t pc_f16 = svwhilelt_c16((uint64_t)w, (uint64_t)W, (uint64_t)2); + svcount_t pc_f32 = svwhilelt_c32((uint64_t)w, (uint64_t)W, (uint64_t)4); + uint64_t num_active = std::min(W - w, svl4x); + + // Load bias in ZA tiles + load_bias_za32(pg_true, weights->bias_conv, O); + + // Perform fused convolution and blur + svfloat16x2_t output_blur = convolve_and_blur_za32( + pc_f16, pg_true, pg_O, src, weights->weight_conv, + weights->weight_blur, weights->bias_blur, O, K, stride); + + // Buffer O rows of ZA tiles + store_za32(pg_true, buff1, O); + + // --- Detection Branch --- + + // Load bias in O rows of ZA tiles + load_bias_za32(pg_true, weights->bias_pointwise_conv0, O); + + // Pointwise convolution + relu_pointwise_convolve_za32(pc_f32, pg_true, pg_O, buff1, + weights->weight_pointwise_conv0, O); + + // Buffer O rows of ZA tiles + store_za32(pg_true, buff2, O); + + 
// Load bias in O rows of ZA tiles + load_bias_za32(pg_true, weights->bias_pointwise_conv1, O); + + // Pointwise convolution + relu_pointwise_convolve_za32(pc_f32, pg_true, pg_O, buff2, + weights->weight_pointwise_conv1, O); + + // Buffer O rows of ZA tiles + store_za32(pg_true, buff2, O); + + // --- Interpolation Branch --- + + // Load bias in O rows of ZA tiles + load_bias_za32(pg_true, weights->bias_interp, O); + + // Pointwise convolution + pointwise_convolve_za32(pc_f32, pg_true, pg_O, buff1, + weights->weight_interp, O); + + // --- Merge Detection and Interpolation Branch --- + + // Multiply interpolation and detection branches and accumulate + svfloat32x4_t output_tuple = svcreate4_f32( + svcvt_f32_f16_x(pg_true, svzip1_f16(svget2_f16(output_blur, 0), + svget2_f16(output_blur, 0))), + svcvt_f32_f16_x(pg_true, svzip2_f16(svget2_f16(output_blur, 0), + svget2_f16(output_blur, 0))), + svcvt_f32_f16_x(pg_true, svzip1_f16(svget2_f16(output_blur, 1), + svget2_f16(output_blur, 1))), + svcvt_f32_f16_x(pg_true, svzip2_f16(svget2_f16(output_blur, 1), + svget2_f16(output_blur, 1)))); + relu_multiply_with_vec_and_reduce_za32(pc_f32, pg_true, buff2, + output_tuple, O); + + // Cast to FP16 and store output + svfloat16x2_t cast = svcreate2_f16( + svuzp1_f16(svcvt_f16_f32_x(pg_true, svget4_f32(output_tuple, 0)), + svcvt_f16_f32_x(pg_true, svget4_f32(output_tuple, 1))), + svuzp1_f16(svcvt_f16_f32_x(pg_true, svget4_f32(output_tuple, 2)), + svcvt_f16_f32_x(pg_true, svget4_f32(output_tuple, 3)))); + svst1_f16_x2(pc_f16, dst, cast); + + // Increment pointers + src += num_active; + dst += num_active; + } + + // Offset pointers to account for padding + src += src_pad_offset; + dst += dst_pad_offset; + } +} + +void relu_multiply_with_vec_and_reduce_scaled_za32( + const svcount_t pc, const svbool_t pg_true, float *src, float16_t *weight, + svfloat32x4_t &accumulator, const size_t C) ARM_STREAMING ARM_INOUT_ZA { + + svfloat32_t res0 = svget4_f32(accumulator, 0); + svfloat32_t res1 = svget4_f32(accumulator, 1); + svfloat32_t res2 = svget4_f32(accumulator, 2); + svfloat32_t res3 = svget4_f32(accumulator, 3); + + for (size_t i = 0; i < C; i += 2) { + + float w0 = weight[i]; + float w1 = weight[i + 1]; + + svfloat32x4_t src_vec0 = svld1_f32_x4(pc, src); + svfloat32x4_t src_vec1 = svld1_vnum_f32_x4(pc, src, 4); + + svfloat32x2_t slice; + + slice = svread_hor_za32_f32_vg2(0, i); + res0 = + svmla_f32_x(pg_true, res0, + svmax_n_f32_x(pg_true, svget4_f32(src_vec0, 0), (float)0.0), + svmul_n_f32_x(pg_true, svget2_f32(slice, 0), w0)); + res0 = + svmla_f32_x(pg_true, res0, + svmax_n_f32_x(pg_true, svget4_f32(src_vec1, 0), (float)0.0), + svmul_n_f32_x(pg_true, svget2_f32(slice, 1), w1)); + + slice = svread_hor_za32_f32_vg2(1, i); + res1 = + svmla_f32_x(pg_true, res1, + svmax_n_f32_x(pg_true, svget4_f32(src_vec0, 1), (float)0.0), + svmul_n_f32_x(pg_true, svget2_f32(slice, 0), w0)); + res1 = + svmla_f32_x(pg_true, res1, + svmax_n_f32_x(pg_true, svget4_f32(src_vec1, 1), (float)0.0), + svmul_n_f32_x(pg_true, svget2_f32(slice, 1), w1)); + + slice = svread_hor_za32_f32_vg2(2, i); + res2 = + svmla_f32_x(pg_true, res2, + svmax_n_f32_x(pg_true, svget4_f32(src_vec0, 2), (float)0.0), + svmul_n_f32_x(pg_true, svget2_f32(slice, 0), w0)); + res2 = + svmla_f32_x(pg_true, res2, + svmax_n_f32_x(pg_true, svget4_f32(src_vec1, 2), (float)0.0), + svmul_n_f32_x(pg_true, svget2_f32(slice, 1), w1)); + + slice = svread_hor_za32_f32_vg2(3, i); + res3 = + svmla_f32_x(pg_true, res3, + svmax_n_f32_x(pg_true, svget4_f32(src_vec0, 3), (float)0.0), + 
svmul_n_f32_x(pg_true, svget2_f32(slice, 0), w0)); + res3 = + svmla_f32_x(pg_true, res3, + svmax_n_f32_x(pg_true, svget4_f32(src_vec1, 3), (float)0.0), + svmul_n_f32_x(pg_true, svget2_f32(slice, 1), w1)); + + src += 8 * svcntw(); + } + + accumulator = svcreate4_f32(res0, res1, res2, res3); +} + +ARM_NEW_ZA void impl_collapsenet_channel_branch_interp_collect_unfused( + float16_t *src, branch_weights_interp_collect_unfused_float16 *weights, + float16_t *dst, const size_t H, const size_t W, const size_t O, + const size_t K, const size_t src_pad, const size_t dst_pad) ARM_STREAMING { + + // Determine vector length + size_t svl = svcntw(); + size_t svl2x = svcnth(); + size_t svl4x = svcntb(); + + // THIS IMPLEMENTATION ONLY WORKS IF OUTPUT CHANNELS ARE LESS THAN + // VECTOR LENGTH FOR 32-BIT DATA TYPE + assert(O <= svl); + + // Create buffers to hold ZA data + float buff1[svl * svl * 4]; + float buff2[svl * svl * 4]; + + // Define useful predicates + svbool_t pg_true = svptrue_b16(); + svbool_t pg_O = svwhilelt_b16((uint64_t)0, 2 * O); + + // Offset to account for padding + size_t src_pad_offset = 2 * src_pad; + size_t dst_pad_offset = 2 * dst_pad; + dst += (W + dst_pad_offset + 1) * dst_pad; + + // Determine stride applied to src pointer after kernel edge is reached + size_t stride = W + src_pad_offset - (K + 1); + + // Main loop with SME Intrinsic + for (size_t h = 0; h < H; h++) { + for (size_t w = 0; w < W; w += svl4x) { + + // Determine active elements of predicate-as-counter register + svcount_t pc_f16 = svwhilelt_c16((uint64_t)w, (uint64_t)W, (uint64_t)2); + svcount_t pc_f32 = svwhilelt_c32((uint64_t)w, (uint64_t)W, (uint64_t)4); + uint64_t num_active = std::min(W - w, svl4x); + + // Load bias in ZA tiles + load_bias_za32(pg_true, weights->bias_conv, O); + + // Perform fused convolution and blur + svfloat16x2_t output_blur = convolve_and_blur_za32( + pc_f16, pg_true, pg_O, src, weights->weight_conv, + weights->weight_blur, weights->bias_blur, O, K, stride); + + // Buffer O rows of ZA tiles + store_za32(pg_true, buff1, O); + + // --- Detection Branch --- + + // Load bias in O rows of ZA tiles + load_bias_za32(pg_true, weights->bias_pointwise_conv0, O); + + // Pointwise convolution + relu_pointwise_convolve_za32(pc_f32, pg_true, pg_O, buff1, + weights->weight_pointwise_conv0, O); + + // Buffer O rows of ZA tiles + store_za32(pg_true, buff2, O); + + // Load bias in O rows of ZA tiles + load_bias_za32(pg_true, weights->bias_pointwise_conv1, O); + + // Pointwise convolution + relu_pointwise_convolve_za32(pc_f32, pg_true, pg_O, buff2, + weights->weight_pointwise_conv1, O); + + // Buffer O rows of ZA tiles + store_za32(pg_true, buff2, O); + + // --- Interpolation Branch --- + + // Load bias in O rows of ZA tiles + load_bias_za32(pg_true, weights->bias_interp, O); + + // Pointwise convolution + pointwise_convolve_za32(pc_f32, pg_true, pg_O, buff1, + weights->weight_interp, O); + + // --- Merge Detection and Interpolation Branch --- + + // Multiply interpolation and detection branches and accumulate + svfloat32x4_t output_tuple = svcreate4_f32( + svcvt_f32_f16_x(pg_true, svzip1_f16(svget2_f16(output_blur, 0), + svget2_f16(output_blur, 0))), + svcvt_f32_f16_x(pg_true, svzip2_f16(svget2_f16(output_blur, 0), + svget2_f16(output_blur, 0))), + svcvt_f32_f16_x(pg_true, svzip1_f16(svget2_f16(output_blur, 1), + svget2_f16(output_blur, 1))), + svcvt_f32_f16_x(pg_true, svzip2_f16(svget2_f16(output_blur, 1), + svget2_f16(output_blur, 1)))); + relu_multiply_with_vec_and_reduce_scaled_za32( + pc_f32, pg_true, 
buff2, weights->weight_collect, output_tuple, O); + + // Cast to FP16 and store output + svfloat16x2_t cast = svcreate2_f16( + svuzp1_f16(svcvt_f16_f32_x(pg_true, svget4_f32(output_tuple, 0)), + svcvt_f16_f32_x(pg_true, svget4_f32(output_tuple, 1))), + svuzp1_f16(svcvt_f16_f32_x(pg_true, svget4_f32(output_tuple, 2)), + svcvt_f16_f32_x(pg_true, svget4_f32(output_tuple, 3)))); + svst1_f16_x2(pc_f16, dst, cast); + + // Increment pointers + src += num_active; + dst += num_active; + } + + // Offset pointers to account for padding + src += src_pad_offset; + dst += dst_pad_offset; + } +} + +#ifdef DENOISER_ZA16 + +void store_za16(const svbool_t pg_true, float16_t *buff, + const size_t C) ARM_STREAMING ARM_INOUT_ZA { + + for (size_t i = 0; i < C; i++) { + svst1_hor_za16(0, i, pg_true, buff); + svst1_hor_za16(1, i, pg_true, buff + svcnth()); + buff += svcntb(); + } +} + +void load_bias_za16(const svbool_t pg_true, float16_t *bias, + const size_t C) ARM_STREAMING ARM_INOUT_ZA { + + for (size_t i = 0; i < C; i += 4) { + svfloat16x4_t b = svcreate4_f16(svdup_n_f16_x(pg_true, bias[i]), + svdup_n_f16_x(pg_true, bias[i + 1]), + svdup_n_f16_x(pg_true, bias[i + 2]), + svdup_n_f16_x(pg_true, bias[i + 3])); + svwrite_hor_za16_f16_vg4(0, i, b); + svwrite_hor_za16_f16_vg4(1, i, b); + } +} + +svfloat16x2_t +convolve_and_blur_za16(const svcount_t pc, const svbool_t pg_true, + const svbool_t pg_w, float16_t *src, float16_t *weight, + float16_t *weight_blur, float16_t *bias_blur, + const size_t C, const size_t K, + const size_t stride) ARM_STREAMING ARM_INOUT_ZA { + + svfloat16_t res0 = svdup_n_f16_x(pg_true, *bias_blur); + svfloat16_t res1 = svdup_n_f16_x(pg_true, *bias_blur); + + for (size_t ki = 0; ki < K; ki++) { + for (size_t kj = 0; kj < K; kj++) { + + float16_t wblur = *weight_blur; + + svfloat16_t Avec = svld1_f16(pg_w, weight); + svfloat16x2_t Bvec = svld1_f16_x2(pc, src); + + svmopa_za16_f16_m(0, pg_w, pg_true, Avec, svget2_f16(Bvec, 0)); + svmopa_za16_f16_m(1, pg_w, pg_true, Avec, svget2_f16(Bvec, 1)); + + res0 = svmla_n_f16_x(pg_true, res0, svget2_f16(Bvec, 0), wblur); + res1 = svmla_n_f16_x(pg_true, res1, svget2_f16(Bvec, 1), wblur); + + weight += C; + weight_blur++; + src++; + } + src += stride; + } + + return svcreate2_f16(res0, res1); +} + +void relu_pointwise_convolve_za16(const svcount_t pc, const svbool_t pg_true, + const svbool_t pg_w, float16_t *src, + float16_t *weight, + const size_t C) ARM_STREAMING ARM_INOUT_ZA { + + for (size_t c = 0; c < C; c++) { + + svfloat16_t Avec = svld1_f16(pg_w, weight); + svfloat16x2_t Bvec = svld1_f16_x2(pc, src); + + svmopa_za16_f16_m( + 0, pg_w, pg_true, Avec, + svmax_n_f16_x(pg_true, svget2_f16(Bvec, 0), (float16_t)0.0)); + svmopa_za16_f16_m( + 1, pg_w, pg_true, Avec, + svmax_n_f16_x(pg_true, svget2_f16(Bvec, 1), (float16_t)0.0)); + + weight += C; + src += svcntb(); + } +} + +void relu_multiply_and_reduce_za16(const svbool_t pg_true, + svfloat16x2_t &accumulator, + const size_t C) ARM_STREAMING ARM_INOUT_ZA { + + svfloat16_t res0 = svget2_f16(accumulator, 0); + svfloat16_t res1 = svget2_f16(accumulator, 1); + + for (size_t i0 = 0; i0 < C; i0 += 2) { + + size_t i1 = i0 + C; + + svfloat16x2_t slice0, slice1; + + slice0 = svread_hor_za16_f16_vg2(0, i0); + slice1 = svread_hor_za16_f16_vg2(0, i1); + res0 = svmla_f16_x( + pg_true, res0, + svmax_n_f16_x(pg_true, svget2_f16(slice0, 0), (float16_t)0.0), + svget2_f16(slice1, 0)); + res0 = svmla_f16_x( + pg_true, res0, + svmax_n_f16_x(pg_true, svget2_f16(slice0, 1), (float16_t)0.0), + svget2_f16(slice1, 1)); + + slice0 
= svread_hor_za16_f16_vg2(1, i0); + slice1 = svread_hor_za16_f16_vg2(1, i1); + res1 = svmla_f16_x( + pg_true, res1, + svmax_n_f16_x(pg_true, svget2_f16(slice0, 0), (float16_t)0.0), + svget2_f16(slice1, 0)); + res1 = svmla_f16_x( + pg_true, res1, + svmax_n_f16_x(pg_true, svget2_f16(slice0, 1), (float16_t)0.0), + svget2_f16(slice1, 1)); + } + + accumulator = svcreate2_f16(res0, res1); +} + +ARM_NEW_ZA void impl_collapsenet_channel_branch_non_widening( + float16_t *src, branch_weights_float16 *weights, float16_t *dst, + const size_t H, const size_t W, const size_t O, const size_t K, + const size_t src_pad, const size_t dst_pad) ARM_STREAMING { + + // Determine vector length + size_t svl = svcnth(); + size_t svl2x = svcntb(); + + // THIS IMPLEMENTATION ONLY WORKS IF OUTPUT CHANNELS ARE LESS THAN THE VECTOR + // LENGTH FOR 16-BIT DATA TYPE + size_t Oh = O / 2; + assert(O <= svl); + + // Create buffer to hold ZA data + float16_t buff[O * svl]; + + // Define useful predicates + svbool_t pg_true = svptrue_b16(); + svbool_t pg_O = svwhilelt_b16((uint64_t)0, O); + svbool_t pg_Oh = svwhilelt_b16((uint64_t)0, Oh); + + // Offset to account for padding + size_t src_pad_offset = 2 * src_pad; + size_t dst_pad_offset = 2 * dst_pad; + dst += (W + dst_pad_offset + 1) * dst_pad; + + // Determine stride applied to src pointer after kernel edge is reached + size_t stride = W + src_pad_offset - K; + + // Main loop with SME Intrinsic + for (size_t h = 0; h < H; h++) { + for (size_t w = 0; w < W; w += svl2x) { + + // Determine active elements of predicate-as-counter register + svcount_t pc = svwhilelt_c16((uint64_t)w, (uint64_t)W, (uint64_t)2); + uint64_t num_active = std::min(W - w, svl2x); + + // Load bias in ZA tiles + load_bias_za16(pg_true, weights->bias_conv, O); + + // Perform fused convolution and blur + svfloat16x2_t output_tuple = convolve_and_blur_za16( + pc, pg_true, pg_O, src, weights->weight_conv, weights->weight_blur, + weights->bias_blur, O, K, stride); + + // --- Detection Branch --- + + // Buffer Oh rows of ZA tiles + store_za16(pg_true, buff, Oh); + + // Load bias in Oh rows of ZA tiles + load_bias_za16(pg_true, weights->bias_pointwise_conv0, Oh); + + // Pointwise convolution + relu_pointwise_convolve_za16(pc, pg_true, pg_Oh, buff, + weights->weight_pointwise_conv0, Oh); + + // Buffer Oh rows of ZA tiles + store_za16(pg_true, buff, Oh); + + // Load bias in Oh rows of ZA tiles + load_bias_za16(pg_true, weights->bias_pointwise_conv1, Oh); + + // Pointwise convolution + relu_pointwise_convolve_za16(pc, pg_true, pg_Oh, buff, + weights->weight_pointwise_conv1, Oh); + + // --- Merge with Interpolation Branch --- + + // Multiply interpolation and detection branches and accumulate into + // blurred output + relu_multiply_and_reduce_za16(pg_true, output_tuple, Oh); + + // Store output + svst1_f16_x2(pc, dst, output_tuple); + + // Increment pointers + src += num_active; + dst += num_active; + } + + // Offset pointers to account for padding + src += src_pad_offset; + dst += dst_pad_offset; + } +} +#endif + +#ifdef DENOISER_MOP4 +void store_za16_mop4(const svbool_t pg_true, float16_t *buff, const size_t C, + const size_t offset) ARM_STREAMING ARM_INOUT_ZA { + + for (size_t i = 0; i < C; i++) { + svst1_hor_za16(0, i, pg_true, buff); + svst1_hor_za16(0, i + offset, pg_true, buff + svcnth()); + svst1_hor_za16(1, i, pg_true, buff + 2 * svcnth()); + svst1_hor_za16(1, i + offset, pg_true, buff + 3 * svcnth()); + buff += 4 * svcnth(); + } +} + +void load_bias_za16_mop4(const svbool_t pg_true, float16_t *bias, + 
const size_t C, + const size_t offset) ARM_STREAMING ARM_INOUT_ZA { + + for (size_t i = 0; i < C; i += 4) { + svfloat16x4_t b = svcreate4_f16(svdup_n_f16_x(pg_true, bias[i]), + svdup_n_f16_x(pg_true, bias[i + 1]), + svdup_n_f16_x(pg_true, bias[i + 2]), + svdup_n_f16_x(pg_true, bias[i + 3])); + svwrite_hor_za16_f16_vg4(0, i, b); + svwrite_hor_za16_f16_vg4(0, i + offset, b); + svwrite_hor_za16_f16_vg4(1, i, b); + svwrite_hor_za16_f16_vg4(1, i + offset, b); + } +} + +svfloat16x4_t convolve_and_blur_za16_mop4( + const svcount_t pc, const svbool_t pg_true, const svbool_t pg_w, + float16_t *src, float16_t *weight, float16_t *weight_blur, + float16_t *bias_blur, const size_t C, const size_t K, const size_t stride, + const uint16_t *lut) ARM_STREAMING ARM_INOUT_ZA { + + svfloat16_t res0 = svdup_n_f16_x(pg_true, *bias_blur); + svfloat16_t res1 = svdup_n_f16_x(pg_true, *bias_blur); + svfloat16_t res2 = svdup_n_f16_x(pg_true, *bias_blur); + svfloat16_t res3 = svdup_n_f16_x(pg_true, *bias_blur); + + for (size_t ki = 0; ki < K; ki++) { + for (size_t kj = 0; kj < K; kj++) { + + float16_t wblur = *weight_blur; + + svfloat16_t Avec = + svtbl_f16(svld1_f16(pg_w, weight), svld1_u16(pg_true, lut)); + svfloat16x4_t Bvec = svld1_f16_x4(pc, src); + + svmop4a_1x2_za16_f16_f16( + 0, Avec, svcreate2_f16(svget4_f16(Bvec, 0), svget4_f16(Bvec, 1))); + svmop4a_1x2_za16_f16_f16( + 1, Avec, svcreate2_f16(svget4_f16(Bvec, 2), svget4_f16(Bvec, 3))); + + res0 = svmla_n_f16_x(pg_true, res0, svget4_f16(Bvec, 0), wblur); + res1 = svmla_n_f16_x(pg_true, res1, svget4_f16(Bvec, 1), wblur); + res2 = svmla_n_f16_x(pg_true, res2, svget4_f16(Bvec, 2), wblur); + res3 = svmla_n_f16_x(pg_true, res3, svget4_f16(Bvec, 3), wblur); + + weight += C; + weight_blur++; + src++; + } + src += stride; + } + + return svcreate4_f16(res0, res1, res2, res3); +} + +void relu_pointwise_convolve_za16_mop4( + const svcount_t pc, const svbool_t pg_true, const svbool_t pg_w, + float16_t *src, float16_t *weight, const size_t C, + const uint16_t *lut) ARM_STREAMING ARM_INOUT_ZA { + + for (size_t c = 0; c < C; c++) { + + svfloat16_t Avec = + svtbl_f16(svld1_f16(pg_w, weight), svld1_u16(pg_true, lut)); + svfloat16x4_t Bvec = svld1_f16_x4(pc, src); + + svmop4a_1x2_za16_f16_f16( + 0, Avec, + svcreate2_f16( + svmax_n_f16_x(pg_true, svget4_f16(Bvec, 0), (float16_t)0.0), + svmax_n_f16_x(pg_true, svget4_f16(Bvec, 1), (float16_t)0.0))); + svmop4a_1x2_za16_f16_f16( + 1, Avec, + svcreate2_f16( + svmax_n_f16_x(pg_true, svget4_f16(Bvec, 2), (float16_t)0.0), + svmax_n_f16_x(pg_true, svget4_f16(Bvec, 3), (float16_t)0.0))); + + weight += C; + src += 4 * svcnth(); + } +} + +void relu_multiply_and_reduce_za16_mop4( + const svbool_t pg_true, svfloat16x4_t &accumulator, const size_t C, + const size_t offset) ARM_STREAMING ARM_INOUT_ZA { + + svfloat16_t res0 = svget4_f16(accumulator, 0); + svfloat16_t res1 = svget4_f16(accumulator, 1); + svfloat16_t res2 = svget4_f16(accumulator, 2); + svfloat16_t res3 = svget4_f16(accumulator, 3); + + for (size_t i0 = 0; i0 < C; i0 += 2) { + + size_t i1 = i0 + C; + + svfloat16x2_t slice0, slice1; + + slice0 = svread_hor_za16_f16_vg2(0, i0); + slice1 = svread_hor_za16_f16_vg2(0, i1); + res0 = svmla_f16_x( + pg_true, res0, + svmax_n_f16_x(pg_true, svget2_f16(slice0, 0), (float16_t)0.0), + svget2_f16(slice1, 0)); + res0 = svmla_f16_x( + pg_true, res0, + svmax_n_f16_x(pg_true, svget2_f16(slice0, 1), (float16_t)0.0), + svget2_f16(slice1, 1)); + + slice0 = svread_hor_za16_f16_vg2(0, i0 + offset); + slice1 = svread_hor_za16_f16_vg2(0, i1 + offset); + 
res1 = svmla_f16_x( + pg_true, res1, + svmax_n_f16_x(pg_true, svget2_f16(slice0, 0), (float16_t)0.0), + svget2_f16(slice1, 0)); + res1 = svmla_f16_x( + pg_true, res1, + svmax_n_f16_x(pg_true, svget2_f16(slice0, 1), (float16_t)0.0), + svget2_f16(slice1, 1)); + + slice0 = svread_hor_za16_f16_vg2(1, i0); + slice1 = svread_hor_za16_f16_vg2(1, i1); + res2 = svmla_f16_x( + pg_true, res2, + svmax_n_f16_x(pg_true, svget2_f16(slice0, 0), (float16_t)0.0), + svget2_f16(slice1, 0)); + res2 = svmla_f16_x( + pg_true, res2, + svmax_n_f16_x(pg_true, svget2_f16(slice0, 1), (float16_t)0.0), + svget2_f16(slice1, 1)); + + slice0 = svread_hor_za16_f16_vg2(1, i0 + offset); + slice1 = svread_hor_za16_f16_vg2(1, i1 + offset); + res3 = svmla_f16_x( + pg_true, res3, + svmax_n_f16_x(pg_true, svget2_f16(slice0, 0), (float16_t)0.0), + svget2_f16(slice1, 0)); + res3 = svmla_f16_x( + pg_true, res3, + svmax_n_f16_x(pg_true, svget2_f16(slice0, 1), (float16_t)0.0), + svget2_f16(slice1, 1)); + } + + accumulator = svcreate4_f16(res0, res1, res2, res3); +} + +ARM_NEW_ZA void impl_collapsenet_channel_branch_non_widening_mop4( + float16_t *src, branch_weights_float16 *weights, float16_t *dst, + const size_t H, const size_t W, const size_t O, const size_t K, + const size_t src_pad, const size_t dst_pad) ARM_STREAMING { + + // Determine vector length + size_t svl = svcnth(); + size_t svl4x = 4 * svl; + + // THIS IMPLEMENTATION ONLY WORKS IF OUTPUT CHANNELS ARE LESS THAN THE VECTOR + // LENGTH FOR 16-BIT DATA TYPE + size_t Oh = O / 2; + assert(2 * O <= svl); + + // Create buffer to hold ZA data + float16_t buff[2 * O * svl]; + + // Define useful predicates + svbool_t pg_true = svptrue_b16(); + svbool_t pg_O = svwhilelt_b16((uint64_t)0, O); + svbool_t pg_Oh = svwhilelt_b16((uint64_t)0, Oh); + + // Initialise lookup table for replicating weights + uint16_t lut[svl]; + for (size_t i = 0; i < 2 * O; i++) + lut[i] = i % O; + + // Offset to account for padding + size_t src_pad_offset = 2 * src_pad; + size_t dst_pad_offset = 2 * dst_pad; + dst += (W + dst_pad_offset + 1) * dst_pad; + + // Determine stride applied to src pointer after kernel edge is reached + size_t stride = W + src_pad_offset - K; + + // Main loop with SME Intrinsic + for (size_t h = 0; h < H; h++) { + for (size_t w = 0; w < W; w += svl4x) { + + // Determine active elements of predicate-as-counter register + svcount_t pc = svwhilelt_c16((uint64_t)w, (uint64_t)W, (uint64_t)4); + uint64_t num_active = std::min(W - w, svl4x); + + // Load bias in ZA tiles + load_bias_za16_mop4(pg_true, weights->bias_conv, O, O); + + // Perform fused convolution and blur + svfloat16x4_t output_tuple = convolve_and_blur_za16_mop4( + pc, pg_true, pg_O, src, weights->weight_conv, weights->weight_blur, + weights->bias_blur, O, K, stride, lut); + + // --- Detection Branch --- + + // Buffer Oh rows of ZA tiles + store_za16_mop4(pg_true, buff, Oh, O); + + // Load bias in Oh rows of ZA tiles + load_bias_za16_mop4(pg_true, weights->bias_pointwise_conv0, Oh, O); + + // Pointwise convolution + relu_pointwise_convolve_za16_mop4( + pc, pg_true, pg_Oh, buff, weights->weight_pointwise_conv0, Oh, lut); + + // Buffer Oh rows of ZA tiles + store_za16_mop4(pg_true, buff, Oh, O); + + // Load bias in Oh rows of ZA tiles + load_bias_za16_mop4(pg_true, weights->bias_pointwise_conv1, Oh, O); + + // Pointwise convolution + relu_pointwise_convolve_za16_mop4( + pc, pg_true, pg_Oh, buff, weights->weight_pointwise_conv1, Oh, lut); + + // --- Merge with Interpolation Branch --- + + // Multiply interpolation and detection 
branches and accumulate into + // blurred output + relu_multiply_and_reduce_za16_mop4(pg_true, output_tuple, Oh, O); + + // Store output + svst1_f16_x4(pc, dst, output_tuple); + + // Increment pointers + src += num_active; + dst += num_active; + } + + // Offset pointers to account for padding + src += src_pad_offset; + dst += dst_pad_offset; + } +} +#endif + +} // namespace sme2 + +} // namespace acle + +#endif \ No newline at end of file
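
For reference, the arithmetic that relu_pointwise_convolve_za32 accumulates into the ZA tiles can be modelled by a plain scalar loop: a 1x1 (pointwise) convolution whose input is passed through a ReLU. The sketch below is illustrative only and is not part of the patch; it ignores the f16-to-f32 widening-MOPA pair layout and the svtrn1/svzip interleaving used by the SME2 kernels, and the [C][O] row-major weight layout, the [C][width] source layout, and the [O][width] accumulator shape are assumptions rather than details taken from the code above.

#include <algorithm>
#include <cstddef>

// Rough scalar model of relu_pointwise_convolve_za32:
// acc[o][x] += weight[c][o] * max(src[c][x], 0) for every input channel c.
// acc is assumed to hold the bias already, mirroring load_bias_za32 being
// called before the MOPA accumulation.
void relu_pointwise_convolve_reference(const float *src, const float *weight,
                                       float *acc, std::size_t C,
                                       std::size_t O, std::size_t width) {
  for (std::size_t c = 0; c < C; ++c)          // input channels
    for (std::size_t o = 0; o < O; ++o)        // output channels
      for (std::size_t x = 0; x < width; ++x)  // spatial positions
        acc[o * width + x] +=
            weight[c * O + o] * std::max(src[c * width + x], 0.0f);
}

A scalar model along these lines also gives a convenient baseline for validating the SME2 kernels against reference output in unit tests.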