diff --git a/ethosu/regor/architecture/architecture_constraints.hpp b/ethosu/regor/architecture/architecture_constraints.hpp index 72c20e6682804d0d0c26a639724f91b61d9cffeb..767957309b58bdfaf614fbd38ce7f205928d4654 100644 --- a/ethosu/regor/architecture/architecture_constraints.hpp +++ b/ethosu/regor/architecture/architecture_constraints.hpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -101,6 +101,7 @@ public: virtual bool SupportsRescale(DataType fromType, DataType toType) = 0; virtual TransposeSupport SupportsTranspose(OpType opType, TransposeType transposeType) = 0; virtual bool SupportsAccumulatorSaveRestore() = 0; + virtual bool SupportsLeakyRelu(bool quantized, DataType type) = 0; bool CanExecute(const ExecutionQuery &query) { @@ -141,7 +142,6 @@ public: } protected: - virtual bool SupportsLeakyRelu(bool quantized, DataType type) = 0; virtual bool SupportsMatMul(OpType opType) = 0; virtual bool SupportsGather(OpType opType) = 0; virtual bool SupportsScatter(OpType opType) = 0; diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp index bd973c1264771ac1ed594f028d00d509ab97fa2d..d4ff52e08360c33912a05c86e02595da582d8a13 100644 --- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp +++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp @@ -2291,10 +2291,6 @@ Operation *TFLiteGraphOptimiser::ConvertLeakyRelu(Graph *const graph, Operation float alpha = attr->alpha; auto ifm = ifmConn->tensor.get(); auto ofm = ofmConn->tensor.get(); - bool quantScalingInvalidOrUnequal = !IsScalingValidAndEqual(*ifmConn, *ofmConn); - ExecutionQuery query{}; - query.quantScalingInvalidOrUnequal = quantScalingInvalidOrUnequal; - query.ifmType = ifm->Type(); if ( alpha == 0 || std::isinf(1 / alpha) ) { @@ -2318,7 +2314,7 @@ Operation *TFLiteGraphOptimiser::ConvertLeakyRelu(Graph *const graph, Operation returnOp = Convert8bitLeakyReluToLUT(graph, operation, alpha); RecordOptimisation(operation, returnOp); } - else if ( alpha < 0 || isConvertedPrelu || !_constraints->CanExecute(query) ) + else if ( alpha < 0 || isConvertedPrelu || !_constraints->SupportsLeakyRelu(!IsScalingValidAndEqual(*ifmConn, *ofmConn), ifm->Type()) ) { // Use 16-bit lowering to Mul + Max or Min + Mul + Relu + Add returnOp = ConvertLeakyRelu16bit(*ifmConn, *ofmConn, operation);