From 189e68e7628b0f3d76230e2a7a6302009175a72a Mon Sep 17 00:00:00 2001 From: Johan Gunnarsson Date: Wed, 19 Feb 2025 11:30:11 +0100 Subject: [PATCH] MLBEDSW-9069: Handle CPU ops in RemoveReshape Signed-off-by: Johan Gunnarsson Change-Id: I95e9e2be8b38b65c41953603ba29ef82805852e6 --- ethosu/regor/compiler/tflite_graph_optimiser.cpp | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp index 744c272a..2a169856 100644 --- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp +++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp @@ -823,15 +823,21 @@ Operation *TFLiteGraphOptimiser::RemoveReshape(Graph *const graph, Operation *co auto *ifm = ifmConn->tensor.get(); auto *ofm = ofmConn->tensor.get(); - // Check if ifm/ofm are network ifm/ofm + // Check if ifm/ofm are network ifm/ofm or constant + bool isIfmConst = ifm->IsConstant(); bool isIfmSgIfm = IsTensorInVector(graph->Inputs(), ifm); bool isOfmSgOfm = IsTensorInVector(graph->Outputs(), ofm); bool isIfmSgOfm = IsTensorInVector(graph->Outputs(), ifm); - // TODO: MLBEDSW-9069: Check CPU operator producer/consumer + // Check if ifm/ofm is produced/consumed by a CPU operation + auto isPassthroughOp = [](const std::shared_ptr &op) { return op->Type() == OpType::Passthrough; }; + const bool isOfmCpuIfm = + std::find_if(ofm->Readers().begin(), ofm->Readers().end(), isPassthroughOp) != ofm->Readers().end(); + const bool isIfmCpuOfm = + std::find_if(ifm->Writers().begin(), ifm->Writers().end(), isPassthroughOp) != ifm->Writers().end(); // Inserts a copy op if needed before removing reshapes. - if ( (isIfmSgIfm || isIfmSgOfm) && (isOfmSgOfm) ) + if ( ((isIfmSgIfm || isIfmSgOfm || isIfmConst) && (isOfmSgOfm)) || (isIfmCpuOfm && isOfmCpuIfm) ) { auto copyOp = InsertCopyOpAfterTensor(ifmConn->tensor, ifmConn->quantization); copyOp->Output(TensorUsage::OFM)->Set(RoundMode::NATURAL); @@ -845,10 +851,8 @@ Operation *TFLiteGraphOptimiser::RemoveReshape(Graph *const graph, Operation *co } // Remove the reshape and one of the tensors. - if ( isOfmSgOfm ) + if ( isOfmSgOfm || isOfmCpuIfm ) { - // TODO: This path should also be used for ofm tensors consumed by CPU ops. - // The OFM is in graph outputs, do not remove this tensor. // Bypass by replacing ifm with ofm. // Set OFM as output for IFM producers -- GitLab