diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp index b303615fd354e03cc922d405697d772020736ef8..95a35cd86ae6eb95626fa78306084850516ed25c 100644 --- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp +++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp @@ -968,7 +968,7 @@ Operation *TFLiteGraphOptimiser::ConvertGather(Graph *const graph, Operation *co return returnOp; } -// Replace TFLite ScatterNd with GraphIR Scatter, if possible. +// Replace TFLite ScatterNd with GraphIR Scatter Operation *TFLiteGraphOptimiser::ConvertScatter(Graph *const graph, Operation *const operation) { UNUSED(graph); @@ -987,33 +987,10 @@ Operation *TFLiteGraphOptimiser::ConvertScatter(Graph *const graph, Operation *c assert(updatesConn); assert(shapeConn); assert(ofmConn); - - // Can only support this op when last dimension is 1 - if ( idxConn->shape[-1] != 1 ) - { - return returnOp; - } - - // Can only support constant index tensors - if ( !idxConn->tensor->IsConstant() ) - { - return operation; - } - // Can not support duplicates in the index tensor - std::unordered_set unique_idxs; - for ( int i = 0; i < idxConn->tensor->View().Elements(); i++ ) - { - int idx = idxConn->tensor->View().RawData()[i]; - if ( unique_idxs.find(idx) != unique_idxs.end() ) - { - return operation; - } - unique_idxs.insert(idx); - } - assert(shapeConn->tensor->IsConstant()); assert(shapeConn->shape.Size() == 1); - + assert(idxConn->shape[-1] == 1); + assert(idxConn->tensor->IsConstant()); // Calculate GraphIR Scatter N dim int N = 1; diff --git a/ethosu/regor/tflite/tflite_supported_operators_u85.cpp b/ethosu/regor/tflite/tflite_supported_operators_u85.cpp index f4c8c9ad20d8a614971b92a8baf054637734218a..fddef99a48383d34940be1197d3ae3892d76c305 100644 --- a/ethosu/regor/tflite/tflite_supported_operators_u85.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators_u85.cpp @@ -120,6 +120,7 @@ TfLiteSupportedOperatorsU85::TfLiteSupportedOperatorsU85(IArchitectureConstraint &TfLiteSupportedOperatorsU85::ConstraintResizeCommon, &TfLiteSupportedOperatorsU85::ConstraintResizeBilinear, &TfLiteSupportedOperatorsU85::ConstraintGather, + &TfLiteSupportedOperatorsU85::ConstraintScatter, }; } @@ -328,4 +329,44 @@ bool TfLiteSupportedOperatorsU85::ConstraintGather(const Operation *op) } return true; } + +bool TfLiteSupportedOperatorsU85::ConstraintScatter(const Operation *op) +{ + OpType opType = op->Type(); + if ( opType != OpType::ScatterNd ) + { + return true; + } + auto *idxConn = op->Input(TensorUsage::IFM0); + auto *shapeConn = op->Input(TensorUsage::Params); + assert(idxConn); + assert(shapeConn); + // index tensor must have C == 1 + if ( idxConn->shape[-1] != 1 ) + { + Failure(op, fmt::format("index shape: {}", idxConn->shape.ToString()), "Channel must be 1 for ScatterNd index tensor"); + return false; + } + // index tensor must be constant + if ( !idxConn->tensor->IsConstant() ) + { + Failure(op, "non-constant index tensor", "index tensor must be constant"); + return false; + } + // shape tensor must be constant + if ( !shapeConn->tensor->IsConstant() ) + { + Failure(op, "non-constant shape tensor", "shape tensor must be constant"); + return false; + } + // Can not support duplicates in the index tensor + const auto idxs = idxConn->tensor->View().Values(); + const std::unordered_set uniqueIdxs(idxs.begin(), idxs.end()); + if ( idxConn->tensor->View().Elements() != int(uniqueIdxs.size()) ) + { + Failure(op, "index tensor contains duplicates", "index tensor elements must be unique"); + return false; + } + return true; +} } // namespace regor diff --git a/ethosu/regor/tflite/tflite_supported_operators_u85.hpp b/ethosu/regor/tflite/tflite_supported_operators_u85.hpp index 8aab9e66f9294f71aafd77f2dbcdae07e03d9225..c6be33fcb1fcdf45114edef1a8fd7233633600b5 100644 --- a/ethosu/regor/tflite/tflite_supported_operators_u85.hpp +++ b/ethosu/regor/tflite/tflite_supported_operators_u85.hpp @@ -42,5 +42,6 @@ private: bool ConstraintResizeCommon(const Operation *op); bool ConstraintResizeBilinear(const Operation *op); bool ConstraintGather(const Operation *op); + bool ConstraintScatter(const Operation *op); }; } // namespace regor