diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp index 92f10a8816919683b4b3e6eeb640caaaf7b57e96..b303615fd354e03cc922d405697d772020736ef8 100644 --- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp +++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp @@ -857,7 +857,7 @@ Operation *TFLiteGraphOptimiser::ConvertReverse(Graph *const graph, Operation *c return returnOp; } -// Replace TFLite GatherV2 and GatherNd with GraphIR Gather, if possible. +// Replace TFLite GatherV2 with GraphIR Gather Operation *TFLiteGraphOptimiser::ConvertGather(Graph *const graph, Operation *const operation) { UNUSED(graph); @@ -899,12 +899,8 @@ Operation *TFLiteGraphOptimiser::ConvertGather(Graph *const graph, Operation *co assert(batchDimsParam <= axisParam); } } - // TODO: Convert to supported ops check // TODO: MLBEDSW-10279 Investigate if constraint can be relaxed - if ( axisParam != batchDimsParam ) - { - return returnOp; - } + assert(axisParam == batchDimsParam); // Calculate GraphIR Gather N dim int N = 1; diff --git a/ethosu/regor/tflite/tflite_supported_operators_u85.cpp b/ethosu/regor/tflite/tflite_supported_operators_u85.cpp index 79f9d44ddc2e86b51cc07436b740523fe112ffb7..f4c8c9ad20d8a614971b92a8baf054637734218a 100644 --- a/ethosu/regor/tflite/tflite_supported_operators_u85.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators_u85.cpp @@ -92,7 +92,6 @@ TfLiteSupportedOperatorsU85::TfLiteSupportedOperatorsU85(IArchitectureConstraint OpType::Abs, OpType::SplitV, OpType::ReverseV2, - OpType::GatherNd, OpType::Quantize, OpType::HardSwish, OpType::SelectV2, @@ -120,6 +119,7 @@ TfLiteSupportedOperatorsU85::TfLiteSupportedOperatorsU85(IArchitectureConstraint _checks = { &TfLiteSupportedOperatorsU85::ConstraintResizeCommon, &TfLiteSupportedOperatorsU85::ConstraintResizeBilinear, + &TfLiteSupportedOperatorsU85::ConstraintGather, }; } @@ -299,4 +299,33 @@ bool TfLiteSupportedOperatorsU85::ConstraintResizeBilinear(const Operation *op) } return true; } + +bool TfLiteSupportedOperatorsU85::ConstraintGather(const Operation *op) +{ + OpType opType = op->Type(); + if ( opType != OpType::GatherV2 ) + { + return true; + } + const tflite::Operator *const passthrough = static_cast(op->Passthrough()); + const auto options = passthrough->builtin_options_as_GatherOptions(); + auto *params = op->Input(TensorUsage::IFM0); + assert(params); + int paramsRank = params->shape.Size(); + int batchDimsParam = 0; + int axisParam = 0; + if ( options ) + { + axisParam = options->axis(); + if ( axisParam < 0 ) axisParam = paramsRank - (-axisParam); + batchDimsParam = options->batch_dims(); + } + + if ( axisParam != batchDimsParam ) + { + Failure(op, fmt::format("axis: {} != batch_dims: {}", axisParam, batchDimsParam), "axis must be equal to batch_dims"); + return false; + } + return true; +} } // namespace regor diff --git a/ethosu/regor/tflite/tflite_supported_operators_u85.hpp b/ethosu/regor/tflite/tflite_supported_operators_u85.hpp index 5a76570e7141e58e25e327e73fa588d1d7811389..8aab9e66f9294f71aafd77f2dbcdae07e03d9225 100644 --- a/ethosu/regor/tflite/tflite_supported_operators_u85.hpp +++ b/ethosu/regor/tflite/tflite_supported_operators_u85.hpp @@ -41,5 +41,6 @@ public: private: bool ConstraintResizeCommon(const Operation *op); bool ConstraintResizeBilinear(const Operation *op); + bool ConstraintGather(const Operation *op); }; } // namespace regor