From 57d2c6929f5007f213cf6e12933c92c5f1c17412 Mon Sep 17 00:00:00 2001 From: Alexander Bengtsson Date: Fri, 21 Mar 2025 17:51:50 +0100 Subject: [PATCH] MLBEDSW-10579: Add supported-ops check for Softmax - Constraint product of IFM W*H to fit within 16-bits Change-Id: Id9e475271416576005a0d4997b3410155f75d515 Signed-off-by: Alexander Bengtsson --- .../tflite/tflite_supported_operators.cpp | 20 +++++++++++++++++++ .../tflite/tflite_supported_operators.hpp | 1 + 2 files changed, 21 insertions(+) diff --git a/ethosu/regor/tflite/tflite_supported_operators.cpp b/ethosu/regor/tflite/tflite_supported_operators.cpp index 4b030ae9..10643b9f 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators.cpp @@ -702,6 +702,25 @@ bool TfLiteSupportedOperators::ConstraintMean(const Operation *op) return true; } +bool TfLiteSupportedOperators::ConstraintSoftmax(const Operation *op) +{ + OpType opType = op->Type(); + if ( opType != OpType::Softmax ) + { + return true; + } + auto ifmConn = op->Input(TensorUsage::IFM); + assert(ifmConn); + static constexpr int maxProd = 1 << 16; + const auto &ifmShape = ifmConn->shape; + if ( ifmShape.ElementsWH() > maxProd ) + { + Failure(op, fmt::format("ifmShape: ({}), W * H = {}", ifmShape.ToString(), ifmShape.ElementsWH()), + "The product of IFM width and height must be less than 65536"); + return false; + } + return true; +} void TfLiteSupportedOperators::Failure(const Operation *op, const std::string &message, const std::string &constraint) { @@ -752,6 +771,7 @@ TfLiteSupportedOperators::TfLiteSupportedOperators(IArchitectureConstraints *con &TfLiteSupportedOperators::ConstraintRsqrt, &TfLiteSupportedOperators::ConstraintConstParams, &TfLiteSupportedOperators::ConstraintMean, + &TfLiteSupportedOperators::ConstraintSoftmax, }; } diff --git a/ethosu/regor/tflite/tflite_supported_operators.hpp b/ethosu/regor/tflite/tflite_supported_operators.hpp index 06b5de79..92552e41 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.hpp +++ b/ethosu/regor/tflite/tflite_supported_operators.hpp @@ -70,5 +70,6 @@ private: bool ConstraintRsqrt(const Operation *op); bool ConstraintConstParams(const Operation *op); bool ConstraintMean(const Operation *op); + bool ConstraintSoftmax(const Operation *op); }; } // namespace regor -- GitLab