From b571999cf5bd387207948f63dbf7f7afa182c3e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Alfv=C3=A9n?= Date: Mon, 13 Jan 2025 13:12:24 +0100 Subject: [PATCH] MLBEDSW-10238: Fix LeakyRelu int16 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Use correct lowering on Ethos-U55 and Ethos-U65 for int16 LeakyRelu Change-Id: I984e66f7dcd83ce54f2a785bd27493dff2e63ed7 Signed-off-by: Johan Alfvén --- ethosu/regor/architecture/architecture_constraints.hpp | 4 ++-- ethosu/regor/compiler/tflite_graph_optimiser.cpp | 6 +----- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/ethosu/regor/architecture/architecture_constraints.hpp b/ethosu/regor/architecture/architecture_constraints.hpp index 72c20e66..76795730 100644 --- a/ethosu/regor/architecture/architecture_constraints.hpp +++ b/ethosu/regor/architecture/architecture_constraints.hpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -101,6 +101,7 @@ public: virtual bool SupportsRescale(DataType fromType, DataType toType) = 0; virtual TransposeSupport SupportsTranspose(OpType opType, TransposeType transposeType) = 0; virtual bool SupportsAccumulatorSaveRestore() = 0; + virtual bool SupportsLeakyRelu(bool quantized, DataType type) = 0; bool CanExecute(const ExecutionQuery &query) { @@ -141,7 +142,6 @@ public: } protected: - virtual bool SupportsLeakyRelu(bool quantized, DataType type) = 0; virtual bool SupportsMatMul(OpType opType) = 0; virtual bool SupportsGather(OpType opType) = 0; virtual bool SupportsScatter(OpType opType) = 0; diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp index bd973c12..d4ff52e0 100644 --- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp +++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp @@ -2291,10 +2291,6 @@ Operation *TFLiteGraphOptimiser::ConvertLeakyRelu(Graph *const graph, Operation float alpha = attr->alpha; auto ifm = ifmConn->tensor.get(); auto ofm = ofmConn->tensor.get(); - bool quantScalingInvalidOrUnequal = !IsScalingValidAndEqual(*ifmConn, *ofmConn); - ExecutionQuery query{}; - query.quantScalingInvalidOrUnequal = quantScalingInvalidOrUnequal; - query.ifmType = ifm->Type(); if ( alpha == 0 || std::isinf(1 / alpha) ) { @@ -2318,7 +2314,7 @@ Operation *TFLiteGraphOptimiser::ConvertLeakyRelu(Graph *const graph, Operation returnOp = Convert8bitLeakyReluToLUT(graph, operation, alpha); RecordOptimisation(operation, returnOp); } - else if ( alpha < 0 || isConvertedPrelu || !_constraints->CanExecute(query) ) + else if ( alpha < 0 || isConvertedPrelu || !_constraints->SupportsLeakyRelu(!IsScalingValidAndEqual(*ifmConn, *ofmConn), ifm->Type()) ) { // Use 16-bit lowering to Mul + Max or Min + Mul + Relu + Add returnOp = ConvertLeakyRelu16bit(*ifmConn, *ofmConn, operation); -- GitLab