From 9d7a89f42500c697bdb9ea64104c251d2c160353 Mon Sep 17 00:00:00 2001 From: Rickard Bolin Date: Wed, 28 Aug 2024 11:00:15 +0000 Subject: [PATCH] MLBEDSW-8459: Add GATHER and SCATTER support checks Can only support constant index tensors where no indexes are duplicated. Change-Id: Iddf44bef8f0c6aca3bef1339aebea60507077540 Signed-off-by: Rickard Bolin --- .../regor/compiler/tflite_graph_optimiser.cpp | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp index c88282ed..571aee37 100644 --- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp +++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp @@ -42,6 +42,7 @@ #include #include #include +#include #include namespace regor @@ -974,6 +975,12 @@ Operation *TFLiteGraphOptimiser::ConvertGather(Graph *const graph, Operation *co assert(batchDimsParam <= axisParam); } } + // TODO: Convert to supported ops check + // TODO: MLBEDSW-10279 Investigate if constraint can be relaxed + if ( axisParam != batchDimsParam ) + { + return returnOp; + } // Calculate GraphIR Gather N dim int N = 1; @@ -1070,7 +1077,23 @@ Operation *TFLiteGraphOptimiser::ConvertScatter(Graph *const graph, Operation *c return returnOp; } - // TODO: MLBEDSW-8459: Add supported ops check for TFLite ScatterND + // Can only support constant index tensors + if ( !idxConn->tensor->IsConstant() ) + { + return operation; + } + // Can not support duplicates in the index tensor + std::unordered_set unique_idxs; + for ( int i = 0; i < idxConn->tensor->View().Elements(); i++ ) + { + int idx = idxConn->tensor->View().RawData()[i]; + if ( unique_idxs.find(idx) != unique_idxs.end() ) + { + return operation; + } + unique_idxs.insert(idx); + } + assert(shapeConn->tensor->IsConstant()); assert(shapeConn->shape.Size() == 1); -- GitLab