diff --git a/kleidicv/include/kleidicv/neon.h b/kleidicv/include/kleidicv/neon.h index 9fc3974ec583eae92ec7b63425d1b9251f3874ac..ecafd7d388ed1cc6adcc3d23a07f9dfab55c8ef5 100644 --- a/kleidicv/include/kleidicv/neon.h +++ b/kleidicv/include/kleidicv/neon.h @@ -170,6 +170,9 @@ class VecTraitsBase : public VectorTypes { return kVectorLength / sizeof(ScalarType); } + // Maximum number of lanes in a vector. + static constexpr size_t max_num_lanes() { return num_lanes(); } + // Loads a single vector from 'src'. static inline void load(const ScalarType *src, VectorType &vec) { vec = vld1q(&src[0]); diff --git a/kleidicv/include/kleidicv/sve2.h b/kleidicv/include/kleidicv/sve2.h index 43e2e40d561a08b0bc172d0d4c5ab3c0788d427b..ebbc5d04e673d17d634c6759b9a4423aa0070805 100644 --- a/kleidicv/include/kleidicv/sve2.h +++ b/kleidicv/include/kleidicv/sve2.h @@ -146,6 +146,11 @@ class VecTraitsBase : public VectorTypes { return static_cast(svcnt()); } + // Maximum number of lanes in a vector. + static constexpr size_t max_num_lanes() KLEIDICV_STREAMING_COMPATIBLE { + return 256 / sizeof(ScalarType); + } + // Loads a single vector from 'src'. static inline void load(Context ctx, const ScalarType *src, VectorType &vec) KLEIDICV_STREAMING_COMPATIBLE { diff --git a/kleidicv/src/arithmetics/sum_api.cpp b/kleidicv/src/arithmetics/sum_api.cpp index 2846cbc83419dbd09df8c357c1cae3e331ea05cc..2959587e80acf002a41c5427a4aa4f23f1e36683 100644 --- a/kleidicv/src/arithmetics/sum_api.cpp +++ b/kleidicv/src/arithmetics/sum_api.cpp @@ -15,7 +15,24 @@ kleidicv_error_t sum(const T *src, size_t src_stride, size_t width, } // namespace neon +namespace sve2 { + +template +kleidicv_error_t sum(const T *src, size_t src_stride, size_t width, + size_t height, T *sum); + +} // namespace sve2 + +namespace sme2 { + +template +kleidicv_error_t sum(const T *src, size_t src_stride, size_t width, + size_t height, T *sum); + +} // namespace sme2 + } // namespace kleidicv KLEIDICV_MULTIVERSION_C_API(kleidicv_sum_f32, &kleidicv::neon::sum, - nullptr, nullptr); + KLEIDICV_SVE2_IMPL_IF(&kleidicv::sve2::sum), + KLEIDICV_SME2_IMPL_IF(&kleidicv::sme2::sum)); diff --git a/kleidicv/src/arithmetics/sum_sc.h b/kleidicv/src/arithmetics/sum_sc.h new file mode 100644 index 0000000000000000000000000000000000000000..8109665cb2a7f1a69a19ec1e900f20b086b8d2c5 --- /dev/null +++ b/kleidicv/src/arithmetics/sum_sc.h @@ -0,0 +1,72 @@ +// SPDX-FileCopyrightText: 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef KLEIDICV_SUM_SC_H +#define KLEIDICV_SUM_SC_H + +#include "kleidicv/kleidicv.h" +#include "kleidicv/sve2.h" +#include "kleidicv/utils.h" + +namespace KLEIDICV_TARGET_NAMESPACE { + +template +class Sum final : public UnrollTwice { + public: + using ContextType = Context; + using VecTraits = KLEIDICV_TARGET_NAMESPACE::VecTraits; + using VectorType = typename VecTraits::VectorType; + + explicit Sum(VectorType &accumulator) KLEIDICV_STREAMING_COMPATIBLE + : accumulator_{accumulator} { + accumulator_ = VecTraits::svdup(0); + } + + void vector_path(ContextType ctx, + VectorType src) KLEIDICV_STREAMING_COMPATIBLE { + accumulator_ = svadd_m(ctx.predicate(), accumulator_, src); + } + + ScalarType get_sum() KLEIDICV_STREAMING_COMPATIBLE { + ScalarType accumulator_final[VecTraits::max_num_lanes()] = {0}; + svst1(VecTraits::svptrue(), accumulator_final, accumulator_); + + ScalarType sum = 0; + for (size_t i = 0; i != VecTraits::num_lanes(); ++i) { + sum += accumulator_final[i]; + } + return sum; + } + + private: + VectorType &accumulator_; +}; + +template +kleidicv_error_t sum_sc(const ScalarType *src, size_t src_stride, size_t width, + size_t height, + ScalarType *sum) KLEIDICV_STREAMING_COMPATIBLE { + using VecTraits = KLEIDICV_TARGET_NAMESPACE::VecTraits; + using VectorType = typename VecTraits::VectorType; + + CHECK_POINTERS(sum); + CHECK_POINTER_AND_STRIDE(src, src_stride, height); + CHECK_IMAGE_SIZE(width, height); + + Rectangle rect{width, height}; + Rows src_rows{src, src_stride}; + + VectorType accumulator; + Sum operation{accumulator}; + + apply_operation_by_rows(operation, rect, src_rows); + + *sum = operation.get_sum(); + + return KLEIDICV_OK; +} + +} // namespace KLEIDICV_TARGET_NAMESPACE + +#endif // KLEIDICV_SUM_SC_H diff --git a/kleidicv/src/arithmetics/sum_sme2.cpp b/kleidicv/src/arithmetics/sum_sme2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..51904798659043e1c44de55b478ff5c4f1f44030 --- /dev/null +++ b/kleidicv/src/arithmetics/sum_sme2.cpp @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 + +#include "sum_sc.h" + +namespace kleidicv::sme2 { + +template +KLEIDICV_LOCALLY_STREAMING KLEIDICV_TARGET_FN_ATTRS kleidicv_error_t +sum(const T *src, size_t src_stride, size_t width, size_t height, T *sum) { + return sum_sc(src, src_stride, width, height, sum); +} + +#define KLEIDICV_INSTANTIATE_TEMPLATE(type) \ + template KLEIDICV_TARGET_FN_ATTRS kleidicv_error_t sum( \ + const type *src, size_t src_stride, size_t width, size_t height, \ + type *sum) + +KLEIDICV_INSTANTIATE_TEMPLATE(float); + +} // namespace kleidicv::sme2 diff --git a/kleidicv/src/arithmetics/sum_sve2.cpp b/kleidicv/src/arithmetics/sum_sve2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a1a0a5cf5626dcf1c94cd452a5120e0a9f34f20b --- /dev/null +++ b/kleidicv/src/arithmetics/sum_sve2.cpp @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 + +#include "sum_sc.h" + +namespace kleidicv::sve2 { + +template +KLEIDICV_TARGET_FN_ATTRS kleidicv_error_t sum(const T *src, size_t src_stride, + size_t width, size_t height, + T *sum) { + return sum_sc(src, src_stride, width, height, sum); +} + +#define KLEIDICV_INSTANTIATE_TEMPLATE(type) \ + template KLEIDICV_TARGET_FN_ATTRS kleidicv_error_t sum( \ + const type *src, size_t src_stride, size_t width, size_t height, \ + type *sum) + +KLEIDICV_INSTANTIATE_TEMPLATE(float); + +} // namespace kleidicv::sve2