diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp index c88282ed0f73b8b74df6e27fa8d4661a6e54fd78..571aee379af8982b8ecdf4fb87101da1b6edb0f0 100644 --- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp +++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp @@ -42,6 +42,7 @@ #include #include #include +#include #include namespace regor @@ -974,6 +975,12 @@ Operation *TFLiteGraphOptimiser::ConvertGather(Graph *const graph, Operation *co assert(batchDimsParam <= axisParam); } } + // TODO: Convert to supported ops check + // TODO: MLBEDSW-10279 Investigate if constraint can be relaxed + if ( axisParam != batchDimsParam ) + { + return returnOp; + } // Calculate GraphIR Gather N dim int N = 1; @@ -1070,7 +1077,23 @@ Operation *TFLiteGraphOptimiser::ConvertScatter(Graph *const graph, Operation *c return returnOp; } - // TODO: MLBEDSW-8459: Add supported ops check for TFLite ScatterND + // Can only support constant index tensors + if ( !idxConn->tensor->IsConstant() ) + { + return operation; + } + // Can not support duplicates in the index tensor + std::unordered_set unique_idxs; + for ( int i = 0; i < idxConn->tensor->View().Elements(); i++ ) + { + int idx = idxConn->tensor->View().RawData()[i]; + if ( unique_idxs.find(idx) != unique_idxs.end() ) + { + return operation; + } + unique_idxs.insert(idx); + } + assert(shapeConn->tensor->IsConstant()); assert(shapeConn->shape.Size() == 1);