From 3931df4849c5024aec808d74a7ed22f897da116b Mon Sep 17 00:00:00 2001 From: Alexander Bengtsson Date: Fri, 14 Mar 2025 23:47:19 +0100 Subject: [PATCH] MLBEDSW-10566: Refactor Mean constraint-checks - Refactor mean constraint checks from TFliteGraphOptimiser to TFLiteSupportedOperators. Change-Id: Ifc420c6db52294a8b1f17ac17dcc713e0b99ded8 Signed-off-by: Alexander Bengtsson --- .../regor/compiler/tflite_graph_optimiser.cpp | 76 ----------- .../test/test_tflite_supported_operators.cpp | 70 +++++++++++ .../tflite/tflite_supported_operators.cpp | 118 +++++++++++++++++- .../tflite/tflite_supported_operators.hpp | 1 + 4 files changed, 187 insertions(+), 78 deletions(-) diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp index ec8c6a54..65a9c1e0 100644 --- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp +++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp @@ -1743,75 +1743,6 @@ Operation *TFLiteGraphOptimiser::ConvertSoftmaxOps(Graph *const graph, Operation return _softmax->ConvertOp(operation); } -static bool MeanOpSupported(Operation *const operation, Shape &reduceAxis, Shape &ifmShape4D) -{ - auto ifmConn = operation->Input(TensorUsage::IFM0); - auto ifm = ifmConn->tensor; - auto axis = operation->Input(TensorUsage::Params)->tensor; - auto axisValues = axis->View().Values(); - auto axisCount = axis->StorageShape().IsEmpty() ? 1 : axis->StorageShape().Depth(); - auto ifmDims = ifmShape4D.Size(); - - // Max kernel size - static constexpr int MAX_MEAN_KERNEL_SIZE = 64 * 64; - // Max size to avoid overflow INT32 - static constexpr int MAX_MEAN_ELEMENTS_INT8 = 2 << 23; // 2²⁴ x 2⁷ = 2³¹ - static constexpr int MAX_MEAN_ELEMENTS_UINT8 = 2 << 22; // 2²³ x 2⁸ = 2³¹ - static constexpr int MAX_MEAN_ELEMENTS_INT16 = 2 << 15; // 2¹⁶ x 2¹⁵ = 2³¹ - - bool supported = false; - - // Compute total number of elements - int elements = 1; - for ( int i = 0; i < ifmDims; ++i ) - { - elements *= reduceAxis[i] ? ifmShape4D[i] : 1; - } - - // Make sure overflow can not occur - switch ( ifm->Type() ) - { - case DataType::Int8: - supported = elements <= MAX_MEAN_ELEMENTS_INT8; - break; - - case DataType::UInt8: - supported = elements <= MAX_MEAN_ELEMENTS_UINT8; - break; - - case DataType::Int16: - supported = elements <= MAX_MEAN_ELEMENTS_INT16; - break; - - default: - supported = false; - break; - } - - // Only support batch 1 - supported = supported && (ifmShape4D.Batch() == 1); - - // Reduced axis must be no greater than MAX_MEAN_KERNEL_SIZE - supported = supported && (reduceAxis.Depth() * ifmShape4D.Depth() <= MAX_MEAN_KERNEL_SIZE); - supported = supported && (reduceAxis.Width() * ifmShape4D.Width() <= MAX_MEAN_KERNEL_SIZE); - supported = supported && (reduceAxis.Height() * ifmShape4D.Height() <= MAX_MEAN_KERNEL_SIZE); - - // Depth is supported if any of h,w,c == 1 - if ( supported && reduceAxis.Depth() ) - { - supported = false; - for ( int i = 1; i < 4; i++ ) - { - if ( ifmShape4D[i] == 1 ) - { - supported = true; - break; - } - } - } - return supported; -} - Operation *TFLiteGraphOptimiser::ConvertMeanOps(Graph *const, Operation *const operation) { auto returnOp = operation; @@ -1838,15 +1769,8 @@ Operation *TFLiteGraphOptimiser::ConvertMeanOps(Graph *const, Operation *const o } // Create a 4D shape to indicate which axis that will be reduced Shape reduceAxis4D = Shape::PadAxes(reduceAxis, 4, 0); - Shape ifmShape4D = Shape::PadAxes(ifmShape, 4, 1); - // Check if it is possible to convert the MEAN - if ( !MeanOpSupported(operation, reduceAxis4D, ifmShape4D) ) - { - return operation; - } - // Fix intermediateShape when keep_dims is false // e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the intermediateShape should be 1xHx1xC Shape intermediateShape = ofmConn->shape; diff --git a/ethosu/regor/test/test_tflite_supported_operators.cpp b/ethosu/regor/test/test_tflite_supported_operators.cpp index a8c6db60..66a0a5d7 100644 --- a/ethosu/regor/test/test_tflite_supported_operators.cpp +++ b/ethosu/regor/test/test_tflite_supported_operators.cpp @@ -376,6 +376,76 @@ TEST_CASE("Supported operators Common") REQUIRE(supportedOps->Check(op.get()) == true); op->Disconnect(); } + + SECTION("ConstraintMean") + { + { + // Supported mean + auto op = CreateOperation(OpType::Mean, Shape(1, 10, 10, 1), DataType::Int8, Shape(1, 1, 10, 1), DataType::Int8); + auto params = CreateTensor("axis", Shape(1), DataType::Int32, std::vector{1}); + op->ConnectInput(TensorUsage::Params, params); + REQUIRE(supportedOps->Check(op.get()) == true); + op->Disconnect(); + } + { + // Batch > 1 is unsupported + auto op = CreateOperation(OpType::Mean, Shape(2, 10, 10, 1), DataType::Int8, Shape(2, 1, 10, 1), DataType::Int8); + auto params = CreateTensor("axis", Shape(1), DataType::Int32, std::vector{1}); + op->ConnectInput(TensorUsage::Params, params); + REQUIRE(supportedOps->Check(op.get()) == false); + op->Disconnect(); + } + { + // Reduced depth only supported if any of H,W,C is 1 + auto op = CreateOperation(OpType::Mean, Shape(1, 2, 10, 5), DataType::Int8, Shape(1, 2, 10, 1), DataType::Int8); + auto params = CreateTensor("axis", Shape(1), DataType::Int32, std::vector{3}); + op->ConnectInput(TensorUsage::Params, params); + REQUIRE(supportedOps->Check(op.get()) == false); + // change height to 1 and validate pass + auto ifmConn = op->Input(TensorUsage::IFM); + auto ofmConn = op->Output(TensorUsage::OFM); + ifmConn->shape = ifmConn->shape.WithHeight(1); + ofmConn->shape = ofmConn->shape.WithHeight(1); + REQUIRE(supportedOps->Check(op.get()) == true); + op->Disconnect(); + } + { + // Kernel_size must not be greater than 64 * 64 + auto op = CreateOperation(OpType::Mean, Shape(1, 64 * 64 + 1, 10, 5), DataType::Int8, Shape(1, 1, 10, 5), DataType::Int8); + auto params = CreateTensor("axis", Shape(1), DataType::Int32, std::vector{1}); + op->ConnectInput(TensorUsage::Params, params); + REQUIRE(supportedOps->Check(op.get()) == false); + // change ifm height to 64*64 and validate pass + auto ifmConn = op->Input(TensorUsage::IFM); + ifmConn->shape = ifmConn->shape.WithHeight(64 * 64); + REQUIRE(supportedOps->Check(op.get()) == true); + op->Disconnect(); + } + { + // max reduced elements uint8 (2²³) + auto op = CreateOperation(OpType::Mean, Shape(1, 1 << 12, 1 << 11, 1), DataType::UInt8, Shape(1, 1, 1, 1), DataType::UInt8); + auto params = CreateTensor("axis", Shape(2), DataType::Int32, std::vector{1, 2}); + op->ConnectInput(TensorUsage::Params, params); + REQUIRE(supportedOps->Check(op.get()) == true); + auto ifmConn = op->Input(TensorUsage::IFM); + // increase height and validate failure + ifmConn->shape = ifmConn->shape.WithHeight(ifmConn->shape.Height() + 1); + REQUIRE(supportedOps->Check(op.get()) == false); + op->Disconnect(); + } + { + // max reduced elements int16 (2¹⁶) + auto op = CreateOperation(OpType::Mean, Shape(1, 1 << 11, 1 << 5, 1), DataType::Int16, Shape(1, 1, 1, 1), DataType::Int16); + auto params = CreateTensor("axis", Shape(2), DataType::Int32, std::vector{1, 2}); + op->ConnectInput(TensorUsage::Params, params); + REQUIRE(supportedOps->Check(op.get()) == true); + auto ifmConn = op->Input(TensorUsage::IFM); + // increase height and validate failure + ifmConn->shape = ifmConn->shape.WithHeight(ifmConn->shape.Height() + 1); + REQUIRE(supportedOps->Check(op.get()) == false); + op->Disconnect(); + } + } } TEST_CASE("Supported operators EthosU55") diff --git a/ethosu/regor/tflite/tflite_supported_operators.cpp b/ethosu/regor/tflite/tflite_supported_operators.cpp index a27463f5..cda0dad1 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators.cpp @@ -572,9 +572,13 @@ bool TfLiteSupportedOperators::ConstraintRsqrt(const Operation *op) bool TfLiteSupportedOperators::ConstraintConstParams(const Operation *op) { OpType opType = op->Type(); - if ( opType != OpType::Slice ) + switch ( opType ) { - return true; + case OpType::Slice: + case OpType::Mean: + break; + default: + return true; } for ( const auto item : op->Inputs().pairs() ) @@ -592,6 +596,115 @@ bool TfLiteSupportedOperators::ConstraintConstParams(const Operation *op) return true; } +bool TfLiteSupportedOperators::ConstraintMean(const Operation *op) +{ + OpType opType = op->Type(); + if ( opType != OpType::Mean ) + { + return true; + } + static constexpr int MAX_MEAN_KERNEL_SIZE = 64 * 64; + static constexpr int MAX_MEAN_ELEMENTS_INT8 = 1 << 24; // 2²⁴ x 2⁷ = 2³¹ + static constexpr int MAX_MEAN_ELEMENTS_UINT8 = 1 << 23; // 2²³ x 2⁸ = 2³¹ + static constexpr int MAX_MEAN_ELEMENTS_INT16 = 1 << 16; // 2¹⁶ x 2¹⁵ = 2³¹ + auto ifmConn = op->Input(TensorUsage::IFM); + auto params = op->Input(TensorUsage::Params); + assert(ifmConn); + assert(params); + auto ifmShape = ifmConn->shape; + auto axisTens = params->tensor; + auto axisCount = axisTens->StorageShape().IsEmpty() ? 1 : axisTens->StorageShape().Depth(); + auto axisValues = axisTens->View().Values(); + + auto axisMask = ifmShape.WithZeros(); + for ( int i = 0; i < axisCount; i++ ) + { + axisMask[axisValues[i]] = 1; + } + + axisMask = Shape::PadAxes(axisMask, 4, 0); + Shape ifmShape4D = Shape::PadAxes(ifmShape, 4, 1); + + auto ifmType = ifmConn->tensor->Type(); + + // Constrain IFM-Batch to 1 + if ( ifmShape4D.Batch() > 1 ) + { + Failure(op, fmt::format("Batch > 1: {}", ifmShape4D.ToString()), "Batch > 1 is not supported"); + return false; + } + + // Reduced depth is only supported if any of IFM H,W,C is 1 + if ( axisMask.Depth() ) + { + bool supported = false; + for ( int i = 1; i < 4; i++ ) + { + if ( ifmShape4D[i] == 1 ) + { + supported = true; + break; + } + } + if ( !supported ) + { + Failure(op, fmt::format("Unsupported depth-reduction. IFM: {}", ifmShape4D.ToString()), "Depth is only supported if any of h,w,c == 1"); + return false; + } + } + + // Reduced axes are represented with their IFM-value + // Non reduced axes are represented by 0 + // e.g. IFM (5,8,7,9) with axis=H,C -> (0,8,0,9) + Shape reducedAxes = ifmShape4D * axisMask; + // Constrain kernel-size + if ( reducedAxes.GreaterMask(Shape(nullptr, 4, MAX_MEAN_KERNEL_SIZE)) != 0 ) + { + static const std::string constraint = fmt::format("Reduced axis must be less than {}", MAX_MEAN_KERNEL_SIZE); + Failure(op, "Reduced axis is too large", constraint); + return false; + } + + // Constrain reduced elements + int elements = 1; + for ( int i = 0; i < axisMask.Size(); i++ ) + { + elements *= axisMask[i] ? ifmShape4D[i] : 1; + } + switch ( ifmConn->tensor->Type() ) + { + case DataType::Int8: + if ( elements > MAX_MEAN_ELEMENTS_INT8 ) + { + static const std::string constraint = fmt::format("max elements (int8) = {}", MAX_MEAN_ELEMENTS_INT8); + Failure(op, fmt::format("Too many reduced elements: {}", elements), constraint); + return false; + } + break; + case DataType::UInt8: + if ( elements > MAX_MEAN_ELEMENTS_UINT8 ) + { + static const std::string constraint = fmt::format("max elements (uint8) = {}", MAX_MEAN_ELEMENTS_UINT8); + Failure(op, fmt::format("Too many reduced elements: {}", elements), constraint); + return false; + } + break; + case DataType::Int16: + if ( elements > MAX_MEAN_ELEMENTS_INT16 ) + { + static const std::string constraint = fmt::format("max elements (int16) = {}", MAX_MEAN_ELEMENTS_INT16); + Failure(op, fmt::format("Too many reduced elements: {}", elements), constraint); + return false; + } + break; + default: + Failure(op, fmt::format("Unsupported Mean IFM type {}", DataTypeToString(ifmType))); + return false; + } + return true; +} + + void TfLiteSupportedOperators::Failure(const Operation *op, const std::string &message, const std::string &constraint) { assert(op); @@ -640,6 +753,7 @@ TfLiteSupportedOperators::TfLiteSupportedOperators(IArchitectureConstraints *con &TfLiteSupportedOperators::ConstraintTCShapes, &TfLiteSupportedOperators::ConstraintRsqrt, &TfLiteSupportedOperators::ConstraintConstParams, + &TfLiteSupportedOperators::ConstraintMean, }; } diff --git a/ethosu/regor/tflite/tflite_supported_operators.hpp b/ethosu/regor/tflite/tflite_supported_operators.hpp index 8473eb1c..06b5de79 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.hpp +++ b/ethosu/regor/tflite/tflite_supported_operators.hpp @@ -69,5 +69,6 @@ private: bool ConstraintTCShapes(const Operation *op); bool ConstraintRsqrt(const Operation *op); bool ConstraintConstParams(const Operation *op); + bool ConstraintMean(const Operation *op); }; } // namespace regor -- GitLab