From 0520927e033261ed4ad373a0d1c9e1142003831c Mon Sep 17 00:00:00 2001 From: Mark Horvath Date: Fri, 7 Jun 2024 16:27:56 +0200 Subject: [PATCH] Fix SVE2/SME2 path of exp_f32 for some inputs In case of SVE2/SME2 the order of some multiplications was incorrect and for inputs resulting almost positive infinte values (like the input value of 88.7F) the calculation was faulty. In case of NEON the change is not functional, it just makes easier to the reader to understand the order of multiplications. --- kleidicv/src/arithmetics/exp_neon.cpp | 2 +- kleidicv/src/arithmetics/exp_sc.h | 2 +- test/api/test_exp.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/kleidicv/src/arithmetics/exp_neon.cpp b/kleidicv/src/arithmetics/exp_neon.cpp index 274dc5ee3..9a4b19ccc 100644 --- a/kleidicv/src/arithmetics/exp_neon.cpp +++ b/kleidicv/src/arithmetics/exp_neon.cpp @@ -60,7 +60,7 @@ class Exp final : public UnrollOnce { float32x4_t s2 = vreinterpretq_f32_u32(e - b); uint32x4_t cmp = vcagtq_f32(n, vdupq_n(192.0F)); float32x4_t r1 = s1 * s1; - float32x4_t r0 = poly * s1 * s2; + float32x4_t r0 = (poly * s1) * s2; return vreinterpretq_f32_u32((cmp & vreinterpretq_u32_f32(r1)) | (~cmp & vreinterpretq_u32_f32(r0))); } diff --git a/kleidicv/src/arithmetics/exp_sc.h b/kleidicv/src/arithmetics/exp_sc.h index fea6b6ad9..41f9f813e 100644 --- a/kleidicv/src/arithmetics/exp_sc.h +++ b/kleidicv/src/arithmetics/exp_sc.h @@ -64,7 +64,7 @@ class Exp final : public UnrollOnce { svfloat32_t s2 = svreinterpret_f32(svsub_x(pg, e, b)); svbool_t cmp = svacgt(pg, n, 192.0F); svfloat32_t r1 = svmul_x(pg, s1, s1); - svfloat32_t r0 = svmul_x(pg, poly, svmul_x(pg, s1, s2)); + svfloat32_t r0 = svmul_x(pg, s2, svmul_x(pg, poly, s1)); return svsel(cmp, r1, r0); } diff --git a/test/api/test_exp.cpp b/test/api/test_exp.cpp index f4f7b2af1..0b87c7aba 100644 --- a/test/api/test_exp.cpp +++ b/test/api/test_exp.cpp @@ -61,7 +61,7 @@ class ExpTestSpecial final : public UnaryOperationTest { static const std::vector& input_values() { static const std::vector kInputValues = { -105.31, -100.07, -81.012, -47.66, -3.1088, -0.21, - 0.7, 6.2, 39.7201, 86.11, 88.947}; + 0.7, 6.2, 39.7201, 86.11, 88.7, 88.947}; return kInputValues; } -- GitLab