From 44e1f94ea4fb5b93ee5c977cda50b8a23f804fba Mon Sep 17 00:00:00 2001 From: Alexander Bengtsson Date: Sun, 6 Apr 2025 10:21:49 +0200 Subject: [PATCH] MLBEDSW-10494: Supported-op checks for patterns (2) - Refactor existing pattern-matching algorithms to use the supported-ops framework. Pattern-matching optimisations will run after reshape-removal, but before single-op checks and any other optimisations. - Remove BatchToSpaceND and SpaceToBatchND from supported-ops as supported variants are now expected to be pattern-matched before reaching the single-op checks. Change-Id: I89faf37838c603aefea843d57632c1bad837e46f Signed-off-by: Alexander Bengtsson --- .../regor/compiler/tflite_graph_optimiser.cpp | 35 +++++++++++++++---- .../regor/compiler/tflite_graph_optimiser.hpp | 20 +++++++---- .../tflite/tflite_supported_operators_u55.cpp | 2 -- .../tflite/tflite_supported_operators_u85.cpp | 2 -- 4 files changed, 41 insertions(+), 18 deletions(-) diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp index e4393105..76d85df4 100644 --- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp +++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp @@ -1598,6 +1598,7 @@ Operation *TFLiteGraphOptimiser::RewriteSquaredDifference(Graph *const, Operatio Operation *TFLiteGraphOptimiser::RewriteSpaceToBatchConvBatchToSpace(Graph *const, Operation *const operation) { auto opType = operation->Type(); + auto returnOp = operation; if ( opType == OpType::DepthwiseConv2D || opType == OpType::Conv2D ) { auto prevOp = operation->IFM(0)->Writers().empty() ? nullptr : operation->IFM(0)->Writers().front().get(); @@ -1608,9 +1609,18 @@ Operation *TFLiteGraphOptimiser::RewriteSpaceToBatchConvBatchToSpace(Graph *cons operation->OFM()->Readers().size() == 1 // No other consumers of BatchToSpaceND input ) { + auto newOp = std::make_shared(*operation); + for ( const auto &[usage, conn] : operation->Inputs().pairs() ) + { + newOp->CopyInput(usage, conn); + } + for ( const auto &[usage, conn] : operation->Outputs().pairs() ) + { + newOp->CopyOutput(usage, conn); + } // Go ahead and short-circuit the SpaceToBatchND and BatchToSpaceND ops - operation->ConnectInput(TensorUsage::IFM0, prevOp->Input(TensorUsage::IFM0)->tensor); - operation->ConnectOutput(TensorUsage::OFM, nextOp->Output(TensorUsage::OFM)->tensor); + newOp->ConnectInput(TensorUsage::IFM0, prevOp->Input(TensorUsage::IFM0)->tensor); + newOp->ConnectOutput(TensorUsage::OFM, nextOp->Output(TensorUsage::OFM)->tensor); // Set new kernel dilation auto blockShape = prevOp->Input(TensorUsage::Params); int count = blockShape->shape[0]; @@ -1627,13 +1637,24 @@ Operation *TFLiteGraphOptimiser::RewriteSpaceToBatchConvBatchToSpace(Graph *cons int ypad = NeededTotalPadding(inputShape.Height(), stride.y, dilatedWH.y); Margin pad = Margin(ypad / 2, xpad / 2, (ypad + 1) / 2, (xpad + 1) / 2); // Set the new kernel with updated dilation and padding - operation->SetKernel(std::make_unique(dilatedKernel.WithPadding(pad))); - // Disconnect the SpaceToBatchND and BatchToSpaceND ops - prevOp->Disconnect(); - nextOp->Disconnect(); + newOp->SetKernel(std::make_unique(dilatedKernel.WithPadding(pad))); + + // Validate that the pattern-matching is supported + if ( _supportedOps->Check(newOp.get()) ) + { + returnOp = newOp.get(); + // Disconnect matched pattern + prevOp->Disconnect(); + nextOp->Disconnect(); + operation->Disconnect(); + } + else + { + newOp->Disconnect(); + } } } - return operation; + return returnOp; } // Fixup Conv2D and DepthwiseConv2D to allow dilation greater than 2. diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.hpp b/ethosu/regor/compiler/tflite_graph_optimiser.hpp index ffa5af53..77b3ef7e 100644 --- a/ethosu/regor/compiler/tflite_graph_optimiser.hpp +++ b/ethosu/regor/compiler/tflite_graph_optimiser.hpp @@ -197,6 +197,19 @@ public: #endif } }, + { + {}, + { + // prerequisite to pattern-matching + &TFLiteGraphOptimiser::RemoveReshape, + // pattern-matching functions + // (must run before supported-operator checks) + // Every pattern-matching function is responsible of calling + // _supportedOperators->Check(newOp) + // before replacing a pattern with newOp + &TFLiteGraphOptimiser::RewriteSpaceToBatchConvBatchToSpace, + } + }, { {}, { @@ -229,13 +242,6 @@ public: { {}, { - &TFLiteGraphOptimiser::RemoveReshape, - } - }, - { - {}, - { - &TFLiteGraphOptimiser::RewriteSpaceToBatchConvBatchToSpace, &TFLiteGraphOptimiser::FixupDilationGT2, &TFLiteGraphOptimiser::FixupBias, &TFLiteGraphOptimiser::ConvertReduceMinMaxAnyAll, diff --git a/ethosu/regor/tflite/tflite_supported_operators_u55.cpp b/ethosu/regor/tflite/tflite_supported_operators_u55.cpp index 36cce2b5..c0cfa418 100644 --- a/ethosu/regor/tflite/tflite_supported_operators_u55.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators_u55.cpp @@ -50,8 +50,6 @@ TfLiteSupportedOperatorsU55::TfLiteSupportedOperatorsU55(IArchitectureConstraint OpType::Softmax, OpType::Tanh, OpType::Pad, - OpType::BatchToSpaceND, - OpType::SpaceToBatchND, OpType::Transpose, OpType::Mean, OpType::Sub, diff --git a/ethosu/regor/tflite/tflite_supported_operators_u85.cpp b/ethosu/regor/tflite/tflite_supported_operators_u85.cpp index 5cb7a0aa..67e89983 100644 --- a/ethosu/regor/tflite/tflite_supported_operators_u85.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators_u85.cpp @@ -93,8 +93,6 @@ TfLiteSupportedOperatorsU85::TfLiteSupportedOperatorsU85(IArchitectureConstraint OpType::SplitV, OpType::ReverseV2, OpType::GatherNd, - OpType::SpaceToBatchND, - OpType::BatchToSpaceND, OpType::Quantize, OpType::HardSwish, OpType::SelectV2, -- GitLab