From 639f3b5555cb0ca4b3b58f84cefe2eef6f41fe9c Mon Sep 17 00:00:00 2001 From: Alexander Bengtsson Date: Mon, 7 Apr 2025 08:13:11 +0200 Subject: [PATCH] MLBEDSW-10230: Add supported-ops checks for StridedSlice - Stride-tensor must be constant - Begin/end must be constant - Begin must be less than end - Stride/begin/end must be decoded like in RewriteStridedSlice The resulting stride must be exclusively over H/W. The resulting stride must not be negative. Change-Id: I5b5fefb061b8b23fe74513c87e69c21d8597a06e Signed-off-by: Alexander Bengtsson --- .../regor/compiler/tflite_graph_optimiser.cpp | 8 -- .../tflite/tflite_supported_operators.cpp | 113 ++++++++++++++++++ .../tflite/tflite_supported_operators.hpp | 1 + 3 files changed, 114 insertions(+), 8 deletions(-) diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp index 76d85df4..de3d6364 100644 --- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp +++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp @@ -733,14 +733,6 @@ Operation *TFLiteGraphOptimiser::RewriteStridedSlice(Graph *const graph, Operati } } - // TODO MLBEDSW-10165: Handle stride < 0 and other dimensions than H and W - if ( sliceStride.LessMask(sliceStride.WithZeros()) || - sliceStride.WithHeight(1).WithWidth(1) != Shape::PadAxes(sliceShape.WithOnes(), 3, 1) ) - { - returnOp->SetPassthroughOp(); - return returnOp; - } - // Create a new memory copy op assert(sliceOffset + sliceShape <= ifmConn->shape); assert(sliceOffset >= ifmConn->shape.WithZeros()); diff --git a/ethosu/regor/tflite/tflite_supported_operators.cpp b/ethosu/regor/tflite/tflite_supported_operators.cpp index 9e047932..90e7bd10 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators.cpp @@ -22,6 +22,7 @@ #include "common/logging.hpp" #include "compiler/op_type.hpp" +#include "compiler/operation_util.hpp" #include "tflite_supported_operators_u55.hpp" #include "tflite_supported_operators_u85.hpp" @@ -577,6 +578,7 @@ bool TfLiteSupportedOperators::ConstraintConstParams(const Operation *op) switch ( opType ) { case OpType::Slice: + case OpType::StridedSlice: case OpType::Mean: case OpType::Pad: case OpType::PadV2: @@ -765,6 +767,116 @@ bool TfLiteSupportedOperators::ConstraintTransposeDims(const Operation *op) return true; } +bool TfLiteSupportedOperators::ConstraintStridedSlice(const Operation *op) +{ + OpType opType = op->Type(); + if ( opType == OpType::StridedSlice ) + { + const auto *ifmConn = op->Input(TensorUsage::IFM); + const auto *ofmConn = op->Output(TensorUsage::OFM); + const auto *beginParmConn = op->Input(TensorUsage::Params0); + const auto *endParamConn = op->Input(TensorUsage::Params1); + const auto *stridesParamConn = op->Input(TensorUsage::Params2); + + // Read StridedSlice attributes + int32_t begin_mask = 0; + int32_t ellipsis_mask = 0; + int32_t end_mask = 0; + int32_t new_axis_mask = 0; + int32_t shrink_axis_mask = 0; + const tflite::Operator *const passthrough = static_cast(op->Passthrough()); + + if ( passthrough ) + { + const auto options = passthrough->builtin_options_as_StridedSliceOptions(); + if ( options ) + { + begin_mask = options->begin_mask(); + ellipsis_mask = options->ellipsis_mask(); + end_mask = options->end_mask(); + new_axis_mask = options->new_axis_mask(); + shrink_axis_mask = options->shrink_axis_mask(); + } + } + + const Shape beginAttr = TensorToShape(beginParmConn->tensor.get(), beginParmConn->shape.Elements()); + const Shape endAttr = TensorToShape(endParamConn->tensor.get(), endParamConn->shape.Elements()); + + const Shape stridesAttr = TensorToShape(stridesParamConn->tensor.get(), stridesParamConn->shape.Elements()); + const int specShapeSize = std::min({beginAttr.Size(), endAttr.Size(), stridesAttr.Size()}); + + // Start off with the full IFM + const int ifmShapeSize = ifmConn->shape.Size(); + Shape sliceOffset(nullptr, ifmShapeSize, 0); + Shape sliceShape(ifmConn->shape); + Shape sliceStride(nullptr, ifmShapeSize, 1); + + // Process each spec + for ( int specIndex = 0, ifmIndex = 0; specIndex < specShapeSize; specIndex++ ) + { + const bool isBegin = (begin_mask & (1 << specIndex)) != 0; + const bool isEllipsis = (ellipsis_mask & (1 << specIndex)) != 0; + const bool isEnd = (end_mask & (1 << specIndex)) != 0; + const bool isNewAxis = (new_axis_mask & (1 << specIndex)) != 0; + const bool isShrink = (shrink_axis_mask & (1 << specIndex)) != 0; + + if ( isEllipsis ) + { + // Skip to the end + ifmIndex = ifmShapeSize - (specShapeSize - specIndex - 1); + assert(ifmIndex >= 0); + assert(ifmIndex <= ifmShapeSize); + } + else + { + if ( !isBegin || isShrink ) + { + // Handle the begin value + int begin = beginAttr[specIndex]; + if ( begin < 0 ) begin = ifmConn->shape[ifmIndex] + begin; + begin = std::clamp(begin, 0, ifmConn->shape[ifmIndex] - 1); + sliceOffset[ifmIndex] = begin; + sliceShape[ifmIndex] = isShrink ? 1 : ifmConn->shape[ifmIndex] - begin; + } + + if ( !isEnd && !isShrink ) + { + // Handle the end value + int end = endAttr[specIndex]; + if ( end < 0 ) end = ifmConn->shape[ifmIndex] + end; + end = std::clamp(end, 1, ifmConn->shape[ifmIndex]); + assert(end > sliceOffset[ifmIndex]); + sliceShape[ifmIndex] = end - sliceOffset[ifmIndex]; + } + + // Handle the stride value + sliceStride[ifmIndex] = stridesAttr[specIndex]; + + // Go to next dimension + ifmIndex++; + } + } + + // TODO MLBEDSW-10165: Handle stride < 0 and other dimensions than H and W + if ( sliceStride.WithHeight(1).WithWidth(1).Elements64() != 1 ) + { + Failure(op, "StridedSlice with unsupported stride axis", "Stride must be over H or W"); + return false; + } + if ( sliceStride.LessMask(sliceStride.WithZeros()) ) + { + Failure(op, "StridedSlice with unsupported negative stride", "Negative stride is not supported"); + return false; + } + if ( !sliceShape.GreaterMask(sliceShape.WithZeros()) ) + { + Failure(op, fmt::format("StridedSlice with invalid sliceShape: {}", sliceShape.ToString()), "sliceShape must be a volume"); + return false; + } + } + return true; +} + void TfLiteSupportedOperators::Failure(const Operation *op, const std::string &message, const std::string &constraint) { assert(op); @@ -823,6 +935,7 @@ TfLiteSupportedOperators::TfLiteSupportedOperators(IArchitectureConstraints *con &TfLiteSupportedOperators::ConstraintSoftmax, &TfLiteSupportedOperators::ConstraintPad, &TfLiteSupportedOperators::ConstraintTransposeDims, + &TfLiteSupportedOperators::ConstraintStridedSlice, }; } diff --git a/ethosu/regor/tflite/tflite_supported_operators.hpp b/ethosu/regor/tflite/tflite_supported_operators.hpp index 24133cca..4b668abf 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.hpp +++ b/ethosu/regor/tflite/tflite_supported_operators.hpp @@ -73,6 +73,7 @@ private: bool ConstraintSoftmax(const Operation *op); bool ConstraintPad(const Operation *op); bool ConstraintTransposeDims(const Operation *op); + bool ConstraintStridedSlice(const Operation *op); }; // Factory for supported-ops checkers -- GitLab