From 1c7d56e0309a858f192926e483c0d1de48de2085 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Alfv=C3=A9n?= Date: Tue, 27 May 2025 19:57:57 +0200 Subject: [PATCH] MLBEDSW-10872: MLCE: Fix LUT table for RSQRT MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - RSQRT operator caused an output mismatch - Missing branch for handling LUT index 0 - Values near zero (index 0) should map to max output value, consistent with the reference implementation Change-Id: Iba4cc6c49281fc749398b05c2edbdcd70caa34d4 Signed-off-by: Johan Alfvén --- ethosu/regor/compiler/tflite_graph_optimiser.cpp | 11 ++++++++++- ethosu/vela/lut.py | 14 +++++++++----- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp index 3aa93dc9..cb069b8c 100644 --- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp +++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp @@ -2385,7 +2385,16 @@ Operation *TFLiteGraphOptimiser::ConvertRSqrtToLUT(Graph *const graph, Operation for ( int x = qMin + 1; x <= qMax; ++x ) { int index = std::max(0, x - int(zpIn)); - auto value = zpOut + MultiplyByQuantizedMultiplier(kRSqrtLut[index], qScale); + int32_t value; + if ( index == 0 ) + { + // Any value close to 0 (zero index in LUT) is mapped to the max output value + value = qMax; + } + else + { + value = zpOut + MultiplyByQuantizedMultiplier(kRSqrtLut[index], qScale); + } lut.push_back(uint8_t(std::min(qMax, std::max(qMin, int(value))))); } diff --git a/ethosu/vela/lut.py b/ethosu/vela/lut.py index 7b563b08..5c20aedd 100644 --- a/ethosu/vela/lut.py +++ b/ethosu/vela/lut.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2020-2021, 2023-2024 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2020-2021, 2023-2025 Arm Limited and/or its affiliates # # SPDX-License-Identifier: Apache-2.0 # @@ -305,7 +305,6 @@ def create_lut_rsqrt_int8_op(op): quantized_min = min(ix) quantized_max = max(ix) - # Any value close to 0 (zero index in LUT) is mapped to the max output value values = [quantized_max] for x in ix: if x == -128: @@ -313,9 +312,14 @@ def create_lut_rsqrt_int8_op(op): continue # Rsqrt is only defined for positive values x_real = max(0, x - zp_in) - val = RSQRT_LUT[x_real] - val = fp_math.multiply_by_quantized_multiplier(val, output_multiplier, output_shift - kshift) + zp_out - lut_result = min(quantized_max, max(quantized_min, val)) + if x_real == 0: + # Any value close to 0 (zero index in LUT) is mapped to the max output value + lut_result = quantized_max + else: + val = RSQRT_LUT[x_real] + val = fp_math.multiply_by_quantized_multiplier(val, output_multiplier, output_shift - kshift) + zp_out + lut_result = min(quantized_max, max(quantized_min, val)) + values.append(lut_result) return convert_to_lut(op, values, "rsqrt") -- GitLab