From ee30e63f32902ec84abcc3ceb2afbb27703cb4a1 Mon Sep 17 00:00:00 2001 From: Michael Platings Date: Thu, 21 Nov 2024 12:35:10 +0000 Subject: [PATCH] Simplify int8-to-float32 conversion Some compilers appear to have difficulty optimizing the previous code so it can impact performance. --- .../kleidicv/conversions/float_conversion.h | 32 +++ kleidicv/src/conversions/float_conv_api.cpp | 42 ++-- kleidicv/src/conversions/float_conv_neon.cpp | 204 +++++++++--------- 3 files changed, 151 insertions(+), 127 deletions(-) create mode 100644 kleidicv/include/kleidicv/conversions/float_conversion.h diff --git a/kleidicv/include/kleidicv/conversions/float_conversion.h b/kleidicv/include/kleidicv/conversions/float_conversion.h new file mode 100644 index 000000000..199a96b3a --- /dev/null +++ b/kleidicv/include/kleidicv/conversions/float_conversion.h @@ -0,0 +1,32 @@ +// SPDX-FileCopyrightText: 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef KLEIDICV_FLOAT_CONVERSION_H +#define KLEIDICV_FLOAT_CONVERSION_H + +#include "kleidicv/kleidicv.h" + +namespace kleidicv { +namespace neon { + +kleidicv_error_t float_conversion_f32_s8(const float *src, size_t src_stride, + int8_t *dst, size_t dst_stride, + size_t width, size_t height); + +kleidicv_error_t float_conversion_f32_u8(const float *src, size_t src_stride, + uint8_t *dst, size_t dst_stride, + size_t width, size_t height); + +kleidicv_error_t float_conversion_s8_f32(const int8_t *src, size_t src_stride, + float *dst, size_t dst_stride, + size_t width, size_t height); + +kleidicv_error_t float_conversion_u8_f32(const uint8_t *src, size_t src_stride, + float *dst, size_t dst_stride, + size_t width, size_t height); + +} // namespace neon +} // namespace kleidicv + +#endif // KLEIDICV_FLOAT_CONVERSION_H diff --git a/kleidicv/src/conversions/float_conv_api.cpp b/kleidicv/src/conversions/float_conv_api.cpp index 0f3fdea2d..4bde2dc11 100644 --- a/kleidicv/src/conversions/float_conv_api.cpp +++ 
b/kleidicv/src/conversions/float_conv_api.cpp @@ -2,21 +2,13 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "kleidicv/conversions/float_conversion.h" #include "kleidicv/dispatch.h" #include "kleidicv/kleidicv.h" #include "kleidicv/types.h" namespace kleidicv { -namespace neon { - -template -kleidicv_error_t float_conversion(const InputType* src, size_t src_stride, - OutputType* dst, size_t dst_stride, - size_t width, size_t height); - -} // namespace neon - namespace sve2 { template @@ -36,36 +28,34 @@ kleidicv_error_t float_conversion(const InputType* src, size_t src_stride, } // namespace sme2 #ifdef KLEIDICV_HAVE_SVE2 -#define SVE2_FUNC_POINTER(name, itype, otype) \ +#define SVE2_FUNC_POINTER(itype, otype) \ [[maybe_unused]] static auto sve2_func_##itype##_##otype = \ kleidicv::sve2::float_conversion; #else -#define SVE2_FUNC_POINTER(name, itype, otype) +#define SVE2_FUNC_POINTER(itype, otype) #endif // KLEIDICV_HAVE_SVE2 #ifdef KLEIDICV_HAVE_SME2 -#define SME2_FUNC_POINTER(name, itype, otype) \ - static auto sme2_func_##itype##_##otype = \ +#define SME2_FUNC_POINTER(itype, otype) \ + static auto sme2_func_##itype##_##otype = \ kleidicv::sme2::float_conversion; #else -#define SME2_FUNC_POINTER(name, itype, otype) +#define SME2_FUNC_POINTER(itype, otype) #endif // KLEIDICV_HAVE_SME2 // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) -#define KLEIDICV_DEFINE_C_API(name, itype, otype) \ - static auto neon_func_##itype##_##otype = \ - kleidicv::neon::float_conversion; \ - SVE2_FUNC_POINTER(name, itype, otype); \ - SME2_FUNC_POINTER(name, itype, otype); \ - KLEIDICV_MULTIVERSION_C_API( \ - name, neon_func_##itype##_##otype, \ - KLEIDICV_SVE2_IMPL_IF(sve2_func_##itype##_##otype), \ +#define KLEIDICV_DEFINE_C_API(partialname, itype, otype) \ + SVE2_FUNC_POINTER(itype, otype); \ + SME2_FUNC_POINTER(itype, otype); \ + KLEIDICV_MULTIVERSION_C_API( \ + kleidicv_##partialname, &kleidicv::neon::partialname, \ + 
KLEIDICV_SVE2_IMPL_IF(sve2_func_##itype##_##otype), \ sme2_func_##itype##_##otype) // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) -KLEIDICV_DEFINE_C_API(kleidicv_float_conversion_f32_s8, float, int8_t); -KLEIDICV_DEFINE_C_API(kleidicv_float_conversion_f32_u8, float, uint8_t); -KLEIDICV_DEFINE_C_API(kleidicv_float_conversion_s8_f32, int8_t, float); -KLEIDICV_DEFINE_C_API(kleidicv_float_conversion_u8_f32, uint8_t, float); +KLEIDICV_DEFINE_C_API(float_conversion_f32_s8, float, int8_t); +KLEIDICV_DEFINE_C_API(float_conversion_f32_u8, float, uint8_t); +KLEIDICV_DEFINE_C_API(float_conversion_s8_f32, int8_t, float); +KLEIDICV_DEFINE_C_API(float_conversion_u8_f32, uint8_t, float); } // namespace kleidicv diff --git a/kleidicv/src/conversions/float_conv_neon.cpp b/kleidicv/src/conversions/float_conv_neon.cpp index 01bddb14c..61a49716b 100644 --- a/kleidicv/src/conversions/float_conv_neon.cpp +++ b/kleidicv/src/conversions/float_conv_neon.cpp @@ -4,7 +4,7 @@ #include -#include "kleidicv/kleidicv.h" +#include "kleidicv/conversions/float_conversion.h" #include "kleidicv/neon.h" namespace kleidicv::neon { @@ -69,97 +69,6 @@ class float_conversion_operation { } }; // end of class float_conversion_operation -template -class float_conversion_operation { - public: - using SrcVecTraits = KLEIDICV_TARGET_NAMESPACE::VecTraits; - using SrcVectorType = typename SrcVecTraits::VectorType; - - float_conversion_operation() : index_{initialize_indexes()} {} - - void process_row(size_t width, Columns src, - Columns dst) { - LoopUnroll{width, SrcVecTraits::num_lanes()} - .unroll_twice([&](size_t step) { - SrcVectorType src0 = vld1q(&src[0]); - SrcVectorType src1 = vld1q(&src[SrcVecTraits::num_lanes()]); - - vector_path(src0, &dst[0]); - vector_path(src1, &dst[SrcVecTraits::num_lanes()]); - - src += ptrdiff_t(step); - dst += ptrdiff_t(step); - }) - .remaining([&](size_t length, size_t) { - for (size_t index = 0; index < length; ++index) { - disable_loop_vectorization(); - 
dst[ptrdiff_t(index)] = src[ptrdiff_t(index)]; - } - }); - } - - private: - static uint8x16x4_t initialize_indexes() { - if constexpr (std::is_signed_v) { - const uint8x16_t index0 = vcombine_u8(vcreate_u8(0x01ffffff00ffffffULL), - vcreate_u8(0x03ffffff02ffffffULL)); - const uint8x16_t index1 = vcombine_u8(vcreate_u8(0x05ffffff04ffffffULL), - vcreate_u8(0x07ffffff06ffffffULL)); - const uint8x16_t index2 = vcombine_u8(vcreate_u8(0x09ffffff08ffffffULL), - vcreate_u8(0x0bffffff0affffffULL)); - const uint8x16_t index3 = vcombine_u8(vcreate_u8(0x0dffffff0cffffffULL), - vcreate_u8(0x0fffffff0effffffULL)); - return {index0, index1, index2, index3}; - } else { - const uint8x16_t index0 = vcombine_u8(vcreate_u8(0xffffff01ffffff00ULL), - vcreate_u8(0xffffff03ffffff02ULL)); - const uint8x16_t index1 = vcombine_u8(vcreate_u8(0xffffff05ffffff04ULL), - vcreate_u8(0xffffff07ffffff06ULL)); - const uint8x16_t index2 = vcombine_u8(vcreate_u8(0xffffff09ffffff08ULL), - vcreate_u8(0xffffff0bffffff0aULL)); - const uint8x16_t index3 = vcombine_u8(vcreate_u8(0xffffff0dffffff0cULL), - vcreate_u8(0xffffff0fffffff0eULL)); - return {index0, index1, index2, index3}; - } - } - - template < - typename I, - std::enable_if_t && std::is_signed_v, int> = 0> - void vector_path(SrcVectorType src, float* dst) { - int32x4_t a = vreinterpretq_s32_u8(vqtbl1q_u8(src, index_.val[0])); - int32x4_t b = vreinterpretq_s32_u8(vqtbl1q_u8(src, index_.val[1])); - int32x4_t c = vreinterpretq_s32_u8(vqtbl1q_u8(src, index_.val[2])); - int32x4_t d = vreinterpretq_s32_u8(vqtbl1q_u8(src, index_.val[3])); - float32x4x4_t output = { - vcvtq_n_f32_s32(a, 24), - vcvtq_n_f32_s32(b, 24), - vcvtq_n_f32_s32(c, 24), - vcvtq_n_f32_s32(d, 24), - }; - vst1q_f32_x4(dst, output); - } - - template < - typename I, - std::enable_if_t && !std::is_signed_v, int> = 0> - void vector_path(SrcVectorType src, float* dst) { - uint32x4_t a = vreinterpretq_u32_u8(vqtbl1q_u8(src, index_.val[0])); - uint32x4_t b = 
vreinterpretq_u32_u8(vqtbl1q_u8(src, index_.val[1])); - uint32x4_t c = vreinterpretq_u32_u8(vqtbl1q_u8(src, index_.val[2])); - uint32x4_t d = vreinterpretq_u32_u8(vqtbl1q_u8(src, index_.val[3])); - float32x4x4_t output = { - vcvtq_f32_u32(a), - vcvtq_f32_u32(b), - vcvtq_f32_u32(c), - vcvtq_f32_u32(d), - }; - vst1q_f32_x4(dst, output); - } - - const uint8x16x4_t index_; -}; // end of class float_conversion_operation - template kleidicv_error_t float_conversion(const InputType* src, size_t src_stride, OutputType* dst, size_t dst_stride, @@ -177,15 +86,108 @@ kleidicv_error_t float_conversion(const InputType* src, size_t src_stride, return KLEIDICV_OK; } -#define KLEIDICV_INSTANTIATE_TEMPLATE(itype, otype) \ - template KLEIDICV_TARGET_FN_ATTRS kleidicv_error_t \ - float_conversion(const itype* src, size_t src_stride, \ - otype* dst, size_t dst_stride, size_t width, \ - size_t height) +kleidicv_error_t float_conversion_f32_s8(const float* src, size_t src_stride, + int8_t* dst, size_t dst_stride, + size_t width, size_t height) { + return float_conversion(src, src_stride, dst, dst_stride, width, height); +} + +kleidicv_error_t float_conversion_f32_u8(const float* src, size_t src_stride, + uint8_t* dst, size_t dst_stride, + size_t width, size_t height) { + return float_conversion(src, src_stride, dst, dst_stride, width, height); +} + +KLEIDICV_TARGET_FN_ATTRS kleidicv_error_t +float_conversion_s8_f32(const int8_t* src, size_t src_stride, float* dst, + size_t dst_stride, size_t width, size_t height) { + CHECK_POINTER_AND_STRIDE(src, src_stride, height); + CHECK_POINTER_AND_STRIDE(dst, dst_stride, height); + CHECK_IMAGE_SIZE(width, height); + + // Indices used with the TBL instruction to widen from 8-bit to 32-bit in a + // single instruction. 
+ const uint8x16_t index0 = vcombine_u8(vcreate_u8(0x01ffffff00ffffffULL), + vcreate_u8(0x03ffffff02ffffffULL)); + const uint8x16_t index1 = vcombine_u8(vcreate_u8(0x05ffffff04ffffffULL), + vcreate_u8(0x07ffffff06ffffffULL)); + const uint8x16_t index2 = vcombine_u8(vcreate_u8(0x09ffffff08ffffffULL), + vcreate_u8(0x0bffffff0affffffULL)); + const uint8x16_t index3 = vcombine_u8(vcreate_u8(0x0dffffff0cffffffULL), + vcreate_u8(0x0fffffff0effffffULL)); + for (size_t y = 0; y != height; ++y) { + size_t x = 0; + for (; x + 16 <= width; x += 16) { + int8x16_t input = vld1q(src + x); + // Widen from 8-bit to 32-bit and shift left 24 bits instead of + // sign-extending. + int32x4_t a = vreinterpretq_s32_s8(vqtbl1q_s8(input, index0)); + int32x4_t b = vreinterpretq_s32_s8(vqtbl1q_s8(input, index1)); + int32x4_t c = vreinterpretq_s32_s8(vqtbl1q_s8(input, index2)); + int32x4_t d = vreinterpretq_s32_s8(vqtbl1q_s8(input, index3)); + // Convert to float and divide by 2^24. + float32x4x4_t output = { + vcvtq_n_f32_s32(a, 24), + vcvtq_n_f32_s32(b, 24), + vcvtq_n_f32_s32(c, 24), + vcvtq_n_f32_s32(d, 24), + }; + vst1q_f32_x4(dst + x, output); + } + for (; x != width; ++x) { + disable_loop_vectorization(); + dst[x] = src[x]; + } + + src += src_stride / sizeof(*src); + dst += dst_stride / sizeof(*dst); + } + return KLEIDICV_OK; +} + +KLEIDICV_TARGET_FN_ATTRS kleidicv_error_t +float_conversion_u8_f32(const uint8_t* src, size_t src_stride, float* dst, + size_t dst_stride, size_t width, size_t height) { + CHECK_POINTER_AND_STRIDE(src, src_stride, height); + CHECK_POINTER_AND_STRIDE(dst, dst_stride, height); + CHECK_IMAGE_SIZE(width, height); -KLEIDICV_INSTANTIATE_TEMPLATE(float, int8_t); -KLEIDICV_INSTANTIATE_TEMPLATE(float, uint8_t); -KLEIDICV_INSTANTIATE_TEMPLATE(int8_t, float); -KLEIDICV_INSTANTIATE_TEMPLATE(uint8_t, float); + // Indices used with the TBL instruction to widen from 8-bit to 32-bit in a + // single instruction. 
+ const uint8x16_t index0 = vcombine_u8(vcreate_u8(0xffffff01ffffff00ULL), + vcreate_u8(0xffffff03ffffff02ULL)); + const uint8x16_t index1 = vcombine_u8(vcreate_u8(0xffffff05ffffff04ULL), + vcreate_u8(0xffffff07ffffff06ULL)); + const uint8x16_t index2 = vcombine_u8(vcreate_u8(0xffffff09ffffff08ULL), + vcreate_u8(0xffffff0bffffff0aULL)); + const uint8x16_t index3 = vcombine_u8(vcreate_u8(0xffffff0dffffff0cULL), + vcreate_u8(0xffffff0fffffff0eULL)); + for (size_t y = 0; y != height; ++y) { + size_t x = 0; + for (; x + 16 <= width; x += 16) { + uint8x16_t input = vld1q(src + x); + // Widen from 8-bit to 32-bit + uint32x4_t a = vreinterpretq_u32_u8(vqtbl1q_u8(input, index0)); + uint32x4_t b = vreinterpretq_u32_u8(vqtbl1q_u8(input, index1)); + uint32x4_t c = vreinterpretq_u32_u8(vqtbl1q_u8(input, index2)); + uint32x4_t d = vreinterpretq_u32_u8(vqtbl1q_u8(input, index3)); + float32x4x4_t output = { + vcvtq_f32_u32(a), + vcvtq_f32_u32(b), + vcvtq_f32_u32(c), + vcvtq_f32_u32(d), + }; + vst1q_f32_x4(dst + x, output); + } + for (; x != width; ++x) { + disable_loop_vectorization(); + dst[x] = src[x]; + } + + src += src_stride / sizeof(*src); + dst += dst_stride / sizeof(*dst); + } + return KLEIDICV_OK; +} } // namespace kleidicv::neon -- GitLab