From 93d1702df1eee74dda17bb38585bcc9398e3d28a Mon Sep 17 00:00:00 2001 From: Johan Gunnarsson Date: Mon, 3 Mar 2025 14:24:25 +0100 Subject: [PATCH] MLBEDSW-9869: Fix rescale fusing to RESIZE RESIZE can do shift only on OFM, and the shift has to be less than 48. In some cases we can support fusing RESCALE with scale if the scale is a power of two by normalizing the scale and shift so that scale is 1. Signed-off-by: Johan Gunnarsson Change-Id: I0b65f3cefcc7fe1cd4c31f5852001ae360682a01 --- .../architecture/ethosu85/ethos_u85_constraints.cpp | 4 +++- .../ethosu85/ethos_u85_register_cs_generator.cpp | 5 +++-- ethosu/regor/common/scaling.cpp | 12 ++++++++++++ ethosu/regor/common/scaling.hpp | 3 ++- ethosu/regor/compiler/graphir_optimiser.cpp | 2 +- 5 files changed, 21 insertions(+), 5 deletions(-) diff --git a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp index 338a65b3..535ede5c 100644 --- a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp +++ b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp @@ -130,7 +130,9 @@ bool EthosU85Constraints::SupportsFusedRescale(OpType opType, TensorUsage tensor else if ( npuOp == EthosU85NpuOp::Resize && globalScale ) { auto &qs = quantization.scales.front(); - return qs.scale == 1 && qs.shift >= 16; // Only shift of 16 or more supported + // Only shift < 48 supported + const auto normalized = QuantizedScale::ReduceScale(qs); + return normalized.scale == 1 && normalized.shift < 48; } else if ( npuOp == EthosU85NpuOp::Elementwise && globalScale ) { diff --git a/ethosu/regor/architecture/ethosu85/ethos_u85_register_cs_generator.cpp b/ethosu/regor/architecture/ethosu85/ethos_u85_register_cs_generator.cpp index 066d4bb2..6a41e69e 100644 --- a/ethosu/regor/architecture/ethosu85/ethos_u85_register_cs_generator.cpp +++ b/ethosu/regor/architecture/ethosu85/ethos_u85_register_cs_generator.cpp @@ -1906,8 +1906,9 @@ void EthosU85RCSGenerator::GenerateResizeOp(HLCStripe *stripe, MemoryAccesses &m // calculate ifm width read int ifmWidthRead = ((ofmShape.Width() - 1) * scale_w.d + offset_w) / scale_w.n + 2; - // scaling is shift only and + 16 - QuantizedScale ofmScale = op->ofm.quantization.scales[0]; + // scaling is shift only and + 16, so convert to scale 1 and add 16 + const QuantizedScale ofmScale = QuantizedScale::ReduceScale(op->ofm.quantization.scales.front()); + assert(ofmScale.scale == 1); int shift = 16 + ofmScale.shift; // X - width diff --git a/ethosu/regor/common/scaling.cpp b/ethosu/regor/common/scaling.cpp index 1032a102..1e5d4d0f 100644 --- a/ethosu/regor/common/scaling.cpp +++ b/ethosu/regor/common/scaling.cpp @@ -81,6 +81,18 @@ const QuantizedScale &QuantizedScale::Unit() return unitScale; } +QuantizedScale QuantizedScale::ReduceScale(const QuantizedScale &qs) +{ + auto scale = qs.scale; + auto shift = qs.shift; + while ( scale > 1 && (scale & 0x1) == 0 && shift > 0 ) + { + scale >>= 1; + shift--; + } + return {scale, shift}; +} + // Convert int32_t multiplier to int16_t with rounding. int16_t DownScaleInt32ToInt16Multiplier(int32_t multiplier) { diff --git a/ethosu/regor/common/scaling.hpp b/ethosu/regor/common/scaling.hpp index 54fd7229..0b4d0360 100644 --- a/ethosu/regor/common/scaling.hpp +++ b/ethosu/regor/common/scaling.hpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2021-2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2021-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -46,6 +46,7 @@ public: * Unit scale, i.e. no scaling */ static const QuantizedScale &Unit(); + static QuantizedScale ReduceScale(const QuantizedScale &qs); }; /* Calculate elementwise Mul OFM QuantizedScale */ diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp index 644b55be..7d0c68d8 100644 --- a/ethosu/regor/compiler/graphir_optimiser.cpp +++ b/ethosu/regor/compiler/graphir_optimiser.cpp @@ -901,7 +901,7 @@ Operation *GraphIrOptimiser::FuseRescale(Graph *const graph, Operation *const op { if ( qs.shift > 0 && qs.shift < 31 && (qs.scale % (1 << qs.shift)) == 0 ) { - qs = {(qs.scale >> qs.shift), 0}; + qs = QuantizedScale::ReduceScale(qs); } } return scales; -- GitLab