From a9f3cd67c0e41bca78c0f6fdf3890e2bb0c7bf27 Mon Sep 17 00:00:00 2001
From: Johan Gunnarsson
Date: Wed, 16 Apr 2025 15:57:30 +0200
Subject: [PATCH] MLBEDSW-10471: Deduplicate RemoveReshape

RemoveReshape is duplicated (and slightly diverging) in
TFLiteGraphOptimiser and GraphIrOptimiser. This patch merges them into
one and puts the result into the parent GraphOptimiser class.

Signed-off-by: Johan Gunnarsson
Change-Id: I1153f21d3a2f3fbdcb71ba2a27e3bd606bfe3ae4
---
 ethosu/regor/compiler/graph_optimiser.cpp     | 74 +++++++++++++++++++
 ethosu/regor/compiler/graph_optimiser.hpp     |  1 +
 ethosu/regor/compiler/graphir_optimiser.cpp   | 65 ----------------
 ethosu/regor/compiler/graphir_optimiser.hpp   |  3 +-
 .../regor/compiler/tflite_graph_optimiser.cpp | 65 ----------------
 .../regor/compiler/tflite_graph_optimiser.hpp |  3 +-
 6 files changed, 77 insertions(+), 134 deletions(-)

diff --git a/ethosu/regor/compiler/graph_optimiser.cpp b/ethosu/regor/compiler/graph_optimiser.cpp
index 9ae4720d..0f7b0789 100644
--- a/ethosu/regor/compiler/graph_optimiser.cpp
+++ b/ethosu/regor/compiler/graph_optimiser.cpp
@@ -25,6 +25,7 @@
 #include "graphir_optimiser.hpp"
 #include "op_type.hpp"
 #include "operation.hpp"
+#include "optimiser_utils.hpp"
 #include "tensor.hpp"
 #include "tflite/tflite_supported_operators.hpp"
 #include "tflite_graph_optimiser.hpp"
@@ -42,6 +43,8 @@
 namespace regor
 {
 
+using namespace GraphOptimisation;
+
 std::unique_ptr<GraphOptimiser> GraphOptimiser::MakeGraphOptimiser(
     GraphNotation notation, Architecture *arch, const GraphOptimiserOptions &options, OptimiserDatabase *db)
 {
@@ -132,6 +135,77 @@ Operation *GraphOptimiser::RecordOptimisation(Graph *const graph, Operation *con
     return operation;
 }
 
+Operation *GraphOptimiser::RemoveReshape(Graph *const graph, Operation *const operation)
+{
+    Operation *returnOp = operation;
+    const OpType opType = operation->Type();
+    if ( IsReshape(opType) )
+    {
+        auto *ifmConn = operation->Input(TensorUsage::IFM0);
+        auto *ofmConn = operation->Output(TensorUsage::OFM);
+        auto *ifm = ifmConn->tensor.get();
+        auto *ofm = ofmConn->tensor.get();
+
+        // Check if ifm/ofm are network ifm/ofm or constant
+        bool isIfmConst = ifm->IsConstant();
+        bool isIfmSgIfm = IsTensorInVector(graph->Inputs(), ifm);
+        bool isOfmSgOfm = IsTensorInVector(graph->Outputs(), ofm);
+        bool isIfmSgOfm = IsTensorInVector(graph->Outputs(), ifm);
+
+        // Check if ifm/ofm is produced/consumed by a CPU operation
+        auto isPassthroughOp = [](const std::shared_ptr<Operation> &op) { return op->Type() == OpType::Passthrough; };
+        const bool isOfmCpuIfm =
+            std::find_if(ofm->Readers().begin(), ofm->Readers().end(), isPassthroughOp) != ofm->Readers().end();
+        const bool isIfmCpuOfm =
+            std::find_if(ifm->Writers().begin(), ifm->Writers().end(), isPassthroughOp) != ifm->Writers().end();
+
+        // Inserts a copy op if needed before removing reshapes.
+        if ( ((isIfmSgIfm || isIfmSgOfm || isIfmConst || isIfmCpuOfm) && (isOfmSgOfm || isOfmCpuIfm)) ||
+             ((ifm->Readers().size() > 1) && (ifm->StorageShape() != ofm->StorageShape() || ifm->AxisOrder() != ofm->AxisOrder())) )
+        {
+            auto copyOp = InsertCopyOpAfterTensor(ifmConn->tensor, ifmConn->quantization);
+            copyOp->Output(TensorUsage::OFM)->Set(RoundMode::NATURAL);
+
+            // reset the ifm to reflect the reshape's new ifm
+            ifmConn = operation->Input(TensorUsage::IFM0);
+            ifm = ifmConn->tensor.get();
+            returnOp = copyOp.get();
+            RecordOptimisation(operation, returnOp);
+            // Reshape still needs to be removed.
+        }
+
+        // Remove the reshape and one of the tensors.
+        if ( isOfmSgOfm || isOfmCpuIfm )
+        {
+            // The OFM is in graph outputs, do not remove this tensor.
+            // Bypass by replacing ifm with ofm.
+            // Set OFM as output for IFM producers
+            ReplaceProducerOutput(ifm->Writers(), ifm, ofmConn->tensor);
+
+            // Set OFM as input to other IFM consumers.
+            ReplaceConsumerInput(operation, ifm->Readers(), ifm, ofmConn->tensor);
+        }
+        else
+        {
+            // Bypass by replacing ofm with ifm.
+            // Set IFM as input to OFM consumers.
+            ReplaceConsumerInput(nullptr, ofm->Readers(), ofm, ifmConn->tensor);
+            assert(ifm->AxisOrder() == AxisOrder::Unknown || ifm->AxisOrder() == ofm->AxisOrder());
+
+            // This is needed as we use the weight tensor, and not the tensor connection,
+            // during weight encode. MLBEDSW-9267
+            ifmConn->tensor->SetAxisOrder(ofm->AxisOrder());
+            ifmConn->tensor->Reshape(ofm->StorageShape());
+        }
+        // Remove the reshape from ifm readers and ofm writers.
+        // Note the Inputs/Outputs on operation should still be intact to not break the traversal.
+        ifm->RemoveReader(operation->shared_from_this());
+        ofm->RemoveWriter(operation->shared_from_this());
+    }
+
+    return returnOp;
+}
+
 void GraphOptimiser::RecordOptimisation(const Operation *operation, const Operation *op)
 {
     if ( _db )
diff --git a/ethosu/regor/compiler/graph_optimiser.hpp b/ethosu/regor/compiler/graph_optimiser.hpp
index 3fbd7cec..0d844cf7 100644
--- a/ethosu/regor/compiler/graph_optimiser.hpp
+++ b/ethosu/regor/compiler/graph_optimiser.hpp
@@ -246,6 +246,7 @@ public:
 #endif
     Operation *RecordOperation(Graph *const graph, Operation *const operation);
     Operation *RecordOptimisation(Graph *const graph, Operation *const operation);
+    Operation *RemoveReshape(Graph *const graph, Operation *const operation);
     void RecordOptimisation(const Operation *operation, const Operation *op);
     void PrintGraph(const Graph *graph, const std::string &label) const;
     void PrintQuantization(const Graph *graph, const std::string &label) const;
diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp
index 5d30cbbe..251b0536 100644
--- a/ethosu/regor/compiler/graphir_optimiser.cpp
+++ b/ethosu/regor/compiler/graphir_optimiser.cpp
@@ -371,71 +371,6 @@ Operation *GraphIrOptimiser::ConstPropagation(Graph *const graph, Operation *con
     return operation;
 }
 
-Operation *GraphIrOptimiser::RemoveReshape(Graph *const graph, Operation *const operation)
-{
-    Operation *returnOp = operation;
-    OpType opType = operation->Type();
-
-    if ( IsReshape(opType) )
-    {
-        auto *ifmConn = operation->Input(TensorUsage::IFM0);
-        auto *ofmConn = operation->Output(TensorUsage::OFM);
-        auto *ifm = ifmConn->tensor.get();
-        auto *ofm = ofmConn->tensor.get();
-
-        // Check if ifm/ofm are network ifm/ofm or constant
-        bool isIfmConst = ifm->IsConstant();
-        bool isIfmSgIfm = IsTensorInVector(graph->Inputs(), ifm);
-        bool isOfmSgOfm = IsTensorInVector(graph->Outputs(), ofm);
-        bool isIfmSgOfm = IsTensorInVector(graph->Outputs(), ifm);
-
-        // TODO: MLBEDSW-9069: Check CPU operator producer/consumer
-        // Inserts a copy op if needed before removing reshapes.
-        if ( ((isIfmSgIfm || isIfmSgOfm || isIfmConst) && (isOfmSgOfm)) ||
-             ((ifm->Readers().size() > 1) && (ifm->StorageShape() != ofm->StorageShape() || ifm->AxisOrder() != ofm->AxisOrder())) )
-        {
-            auto copyOp = InsertCopyOpAfterTensor(ifmConn->tensor, ifmConn->quantization);
-            // reset the ifm to reflect the reshape's new ifm
-            ifmConn = operation->Input(TensorUsage::IFM0);
-            ifm = ifmConn->tensor.get();
-            returnOp = copyOp.get();
-            RecordOptimisation(operation, returnOp);
-            // Reshape still needs to be removed.
-        }
-
-        // Remove the reshape and one of the tensors.
-        if ( isOfmSgOfm )
-        {
-            // TODO: This path should also be used for ofm tensors consumed by CPU ops.
-
-            // The OFM is in graph outputs, do not remove this tensor.
-            // Bypass by replacing ifm with ofm.
-            // Set OFM as output for IFM producers
-            ReplaceProducerOutput(ifm->Writers(), ifm, ofmConn->tensor);
-
-            // Set OFM as input to other IFM consumers.
-            ReplaceConsumerInput(operation, ifm->Readers(), ifm, ofmConn->tensor);
-        }
-        else
-        {
-            // Bypass by replacing ofm with ifm.
-            // Set IFM as input to OFM consumers.
-            ReplaceConsumerInput(nullptr, ofm->Readers(), ofm, ifmConn->tensor);
-            assert(ifm->AxisOrder() == AxisOrder::Unknown || ifm->AxisOrder() == ofm->AxisOrder());
-            // This is needed as we use the weight tensor, and not the tensor connection,
-            // during weight encode. MLBEDSW-9267
-            ifmConn->tensor->SetAxisOrder(ofm->AxisOrder());
-            ifmConn->tensor->Reshape(ofm->StorageShape());
-        }
-        // Remove the reshape from ifm readers and ofm writers.
-        // Note the Inputs/Outputs on operation should still be intact to not break the traversal.
-        ifm->RemoveReader(operation->shared_from_this());
-        ofm->RemoveWriter(operation->shared_from_this());
-    }
-
-    return returnOp;
-}
-
 Operation *GraphIrOptimiser::RewriteFullyConnected(Graph *const graph, Operation *const operation)
 {
     UNUSED(graph);
diff --git a/ethosu/regor/compiler/graphir_optimiser.hpp b/ethosu/regor/compiler/graphir_optimiser.hpp
index 0294ed01..7133f0ac 100644
--- a/ethosu/regor/compiler/graphir_optimiser.hpp
+++ b/ethosu/regor/compiler/graphir_optimiser.hpp
@@ -41,7 +41,6 @@ class GraphIrOptimiser : public GraphOptimiser
 private:
     Operation *ConstPropagation(Graph *const graph, Operation *const operation);
-    Operation *RemoveReshape(Graph *const graph, Operation *const operation);
     Operation *ConvertAttributes(Graph *const graph, Operation *const operation);
     Operation *ConvertResizeOffsets(Graph *const graph, Operation *const operation);
     Tensor *ConvertInt48Tensors(Graph *graph, Tensor *tensor);
@@ -102,7 +101,7 @@
             &GraphIrOptimiser::ConvertInt4Tensors,
         },
         {
-            &GraphIrOptimiser::RemoveReshape,
+            &GraphOptimiser::RemoveReshape,
         }
     },
     {
diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp
index 1d71c709..92f10a88 100644
--- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp
+++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp
@@ -799,71 +799,6 @@ Operation *TFLiteGraphOptimiser::RewriteUnpack(Graph *const graph, Operation *co
 }
 
-Operation *TFLiteGraphOptimiser::RemoveReshape(Graph *const graph, Operation *const operation)
-{
-    Operation *returnOp = operation;
-    OpType opType = operation->Type();
-
-    if ( IsReshape(opType) )
-    {
-        auto *ifmConn = operation->Input(TensorUsage::IFM0);
-        auto *ofmConn = operation->Output(TensorUsage::OFM);
-        auto *ifm = ifmConn->tensor.get();
-        auto *ofm = ofmConn->tensor.get();
-
-        // Check if ifm/ofm are network ifm/ofm or constant
-        bool isIfmConst = ifm->IsConstant();
-        bool isIfmSgIfm = IsTensorInVector(graph->Inputs(), ifm);
-        bool isOfmSgOfm = IsTensorInVector(graph->Outputs(), ofm);
-        bool isIfmSgOfm = IsTensorInVector(graph->Outputs(), ifm);
-
-        // Check if ifm/ofm is produced/consumed by a CPU operation
-        auto isPassthroughOp = [](const std::shared_ptr<Operation> &op) { return op->Type() == OpType::Passthrough; };
-        const bool isOfmCpuIfm =
-            std::find_if(ofm->Readers().begin(), ofm->Readers().end(), isPassthroughOp) != ofm->Readers().end();
-        const bool isIfmCpuOfm =
-            std::find_if(ifm->Writers().begin(), ifm->Writers().end(), isPassthroughOp) != ifm->Writers().end();
-
-        // Inserts a copy op if needed before removing reshapes.
-        if ( ((isIfmSgIfm || isIfmSgOfm || isIfmConst || isIfmCpuOfm) && (isOfmSgOfm || isOfmCpuIfm)) )
-        {
-            auto copyOp = InsertCopyOpAfterTensor(ifmConn->tensor, ifmConn->quantization);
-            copyOp->Output(TensorUsage::OFM)->Set(RoundMode::NATURAL);
-
-            // reset the ifm to reflect the reshape's new ifm
-            ifmConn = operation->Input(TensorUsage::IFM0);
-            ifm = ifmConn->tensor.get();
-            returnOp = copyOp.get();
-            RecordOptimisation(operation, returnOp);
-            // Reshape still needs to be removed.
-        }
-
-        // Remove the reshape and one of the tensors.
-        if ( isOfmSgOfm || isOfmCpuIfm )
-        {
-            // The OFM is in graph outputs, do not remove this tensor.
-            // Bypass by replacing ifm with ofm.
-            // Set OFM as output for IFM producers
-            ReplaceProducerOutput(ifm->Writers(), ifm, ofmConn->tensor);
-
-            // Set OFM as input to other IFM consumers.
-            ReplaceConsumerInput(operation, ifm->Readers(), ifm, ofmConn->tensor);
-        }
-        else
-        {
-            // Bypass by replacing ofm with ifm.
-            // Set IFM as input to OFM consumers.
-            ReplaceConsumerInput(nullptr, ofm->Readers(), ofm, ifmConn->tensor);
-        }
-        // Remove the reshape from ifm readers and ofm writers.
-        // Note the Inputs/Outputs on operation should still be intact to not break the traversal.
-        ifm->RemoveReader(operation->shared_from_this());
-        ofm->RemoveWriter(operation->shared_from_this());
-    }
-
-    return returnOp;
-}
-
 // Convert ReverseV2 into TOSA Reverse
 // ReverseV2 supports a vector of axes, while TOSA reverse only supports one axis
 // If there is more than one reversed axis, convert to a sequence of Reverse operations.
diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.hpp b/ethosu/regor/compiler/tflite_graph_optimiser.hpp
index c5b8b473..fd77451f 100644
--- a/ethosu/regor/compiler/tflite_graph_optimiser.hpp
+++ b/ethosu/regor/compiler/tflite_graph_optimiser.hpp
@@ -105,7 +105,6 @@
     Operation *RewriteUnpack(Graph *const graph, Operation *const operation);
     Operation *RewriteSlice(Graph *const graph, Operation *const operation);
    Operation *RewriteStridedSlice(Graph *const graph, Operation *const operation);
-    Operation *RemoveReshape(Graph *const graph, Operation *const operation);
     Operation *ConvertReverse(Graph *const graph, Operation *const operation);
     Operation *ConvertGather(Graph *const graph, Operation *const operation);
     Operation *ConvertScatter(Graph *const graph, Operation *const operation);
@@ -240,7 +239,7 @@ public:
     {
         {},
         {
-            &TFLiteGraphOptimiser::RemoveReshape,
+            &GraphOptimiser::RemoveReshape,
         }
     },
     {
--
GitLab
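
Note: the derived pass tables can reference &GraphOptimiser::RemoveReshape directly because C++ implicitly converts a pointer to a base-class member function into a pointer to a derived-class member function. A minimal sketch of the pattern follows; the Pass alias, OtherPass, and the table layout are illustrative assumptions, not the actual regor types:

    #include <cstdio>

    struct GraphOptimiser {
        // Shared pass, now living in the base class.
        int RemoveReshape(int x) { return x + 1; }
    };

    struct GraphIrOptimiser : GraphOptimiser {
        int OtherPass(int x) { return x * 2; }

        // The table stores pointers-to-member of the derived class.
        // &GraphOptimiser::RemoveReshape converts implicitly, because a
        // base-class member can always be invoked on a derived object,
        // so both entries fit in the same table.
        using Pass = int (GraphIrOptimiser::*)(int);
        static constexpr Pass passes[] = {
            &GraphOptimiser::RemoveReshape,
            &GraphIrOptimiser::OtherPass,
        };
    };

    int main() {
        GraphIrOptimiser opt;
        int v = 1;
        for ( auto pass : GraphIrOptimiser::passes ) v = (opt.*pass)(v);
        std::printf("%d\n", v);  // (1 + 1) * 2 = 4
    }

This is why only the table entries change in graphir_optimiser.hpp and tflite_graph_optimiser.hpp, while the traversal machinery stays untouched.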
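The first branch of the merged copy-insertion condition reduces to a "both endpoints pinned" test: a copy op is only required when neither the IFM nor the OFM may be absorbed into the other. A sketch of that decision in isolation (hypothetical TensorInfo/NeedsCopy names; the multi-reader shape-mismatch clause of the real condition is omitted for brevity):

    #include <cassert>

    // Simplified stand-in for the tensor properties the pass inspects.
    struct TensorInfo {
        bool isConst = false;       // constant data
        bool isGraphInput = false;  // network input
        bool isGraphOutput = false; // network output
        bool isCpuFacing = false;   // produced/consumed by a Passthrough (CPU) op
    };

    // True when a copy op must be inserted before the reshape is removed:
    // both endpoints are pinned, so neither tensor can be dropped.
    bool NeedsCopy(const TensorInfo &ifm, const TensorInfo &ofm)
    {
        const bool ifmPinned = ifm.isConst || ifm.isGraphInput || ifm.isGraphOutput || ifm.isCpuFacing;
        const bool ofmPinned = ofm.isGraphOutput || ofm.isCpuFacing;
        return ifmPinned && ofmPinned;
    }

    int main()
    {
        // A reshape straight from a network input to a network output
        // cannot be bypassed in place; a copy decouples the pinned ends.
        TensorInfo ifm, ofm;
        ifm.isGraphInput = true;
        ofm.isGraphOutput = true;
        assert(NeedsCopy(ifm, ofm));

        // With an internal OFM the reshape folds away with no copy.
        ofm.isGraphOutput = false;
        assert(!NeedsCopy(ifm, ofm));
    }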
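The removal itself is a rewiring: every consumer of the reshape's OFM is repointed at its IFM (or vice versa when the OFM is pinned), after which the reshape has no readers or writers and drops out. A toy version of the consumer-replacement step, with hypothetical Op/Tensor structures standing in for regor's tensor connections:

    #include <algorithm>
    #include <cassert>
    #include <vector>

    struct Tensor;
    struct Op {
        std::vector<Tensor *> inputs;
    };
    struct Tensor {
        std::vector<Op *> readers;
    };

    // Repoint every reader of `from` at `to`, mirroring the reader lists.
    void ReplaceConsumerInput(Tensor *from, Tensor *to)
    {
        for ( Op *reader : from->readers )
            std::replace(reader->inputs.begin(), reader->inputs.end(), from, to);
        to->readers.insert(to->readers.end(), from->readers.begin(), from->readers.end());
        from->readers.clear();
    }

    int main()
    {
        Tensor ifm, ofm;
        Op consumer;                  // some op that reads the reshape's ofm
        consumer.inputs = {&ofm};
        ofm.readers = {&consumer};

        ReplaceConsumerInput(&ofm, &ifm);  // bypass the reshape

        assert(consumer.inputs[0] == &ifm);  // consumer now reads the ifm
        assert(ofm.readers.empty());         // ofm is orphaned and can be dropped
    }

The real pass additionally copies the OFM's axis order and storage shape onto the surviving IFM (MLBEDSW-9267), which the merged implementation keeps from the GraphIR variant.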