diff --git a/ethosu/regor/compiler/graph.hpp b/ethosu/regor/compiler/graph.hpp index 90b3d72b3cea6a5acdcbe360304056bd36124123..6c8c31977b8597e6f37321ea526a1494b6296371 100644 --- a/ethosu/regor/compiler/graph.hpp +++ b/ethosu/regor/compiler/graph.hpp @@ -128,7 +128,7 @@ public: // Finds all operations which precede a graph output and adds them to the vector in execution order void GetAllOperations(std::vector &operations) const { - TraverseGraphFromEnd(Outputs(), + TraverseGraphFromEnd(Outputs(), !Persistent().empty(), [&](Operation *op) -> bool { operations.push_back(op); @@ -138,7 +138,7 @@ public: void GetAllOperations(std::vector> &operations) const { - TraverseGraphFromEnd(Outputs(), + TraverseGraphFromEnd(Outputs(), !Persistent().empty(), [&](Operation *op) -> bool { operations.push_back(op->shared_from_this()); @@ -151,16 +151,21 @@ public: void SetScheduledOrder(std::vector operations) { _opsInScheduledOrder = std::move(operations); } - // Traverse the graph in right-to-left reverse post-order but processing tensor writers left-to-right. - // This means in below graph, where A and B both write to the input tensor of C, A will be processed - // before B. + // Traverse the graph in right-to-left reverse post-order. + // + // TODO MLBEDSW-10790: Remove special handling of graphs with persistent tensors + // Special care is required if the graph contains any persistent tensors, as they may be written/read + // multiple times. To ensure a functional execution order in this case, the traversal will process tensor + // writers left-to-right. This preserves the order in whcih the operations were added to the graph. + // + // For example, in the graph below, where operations A and B both write to the input tensor of operation C. + // A will be processed before B (opposite of the default traversal). // A B // \ / // | // C - // The rationale is to preserve the order that partial writes are added to the graph. template - static void TraverseGraphFromEnd(const std::vector> &from, OPFUNC opFunc) + static void TraverseGraphFromEnd(const std::vector> &from, bool traverseLeftToRight, OPFUNC opFunc) { struct Entry { @@ -174,10 +179,21 @@ public: for ( const auto &tensor : from ) { - const auto &writers = tensor->Writers(); - for ( auto it = writers.crbegin(); it != writers.crend(); it++ ) + // TODO MLBEDSW-10790: Remove special handling of graphs with persistent tensors + if ( traverseLeftToRight ) { - stack.emplace(false, *it); + const auto &writers = tensor->Writers(); + for ( auto it = writers.crbegin(); it != writers.crend(); it++ ) + { + stack.emplace(false, *it); + } + } + else + { + for ( const auto &op : tensor->Writers() ) + { + stack.emplace(false, op); + } } } @@ -198,12 +214,26 @@ public: stack.emplace(true, entry.op); for ( const auto &pair : entry.op->Inputs().pairs() ) { - const auto &writers = pair.second.tensor->Writers(); - for ( auto it = writers.crbegin(); it != writers.crend(); it++ ) + // TODO MLBEDSW-10790: Remove special handling of graphs with persistent tensors + if ( traverseLeftToRight ) + { + const auto &writers = pair.second.tensor->Writers(); + for ( auto it = writers.crbegin(); it != writers.crend(); it++ ) + { + if ( visited.count(it->get()) == 0 ) + { + stack.emplace(false, *it); + } + } + } + else { - if ( visited.count(it->get()) == 0 ) + for ( const auto &op : pair.second.tensor->Writers() ) { - stack.emplace(false, *it); + if ( visited.count(op.get()) == 0 ) + { + stack.emplace(false, op); + } } } } diff --git a/ethosu/regor/compiler/graph_builder.cpp b/ethosu/regor/compiler/graph_builder.cpp index 3dd3ec06c5e221b8ea75e3754b7ed46816d1ede2..7a11736ae899595e52c14fa336fe110d152159a5 100644 --- a/ethosu/regor/compiler/graph_builder.cpp +++ b/ethosu/regor/compiler/graph_builder.cpp @@ -548,7 +548,7 @@ void GraphBuilder::FreeUnconnected() { // In case somebody added self-supporting graph fragments std::unordered_set connected; - Graph::TraverseGraphFromEnd(_outputs, + Graph::TraverseGraphFromEnd(_outputs, !_persistent.empty(), [&](Operation *op) -> bool { connected.insert(op); diff --git a/ethosu/regor/compiler/graph_optimiser.hpp b/ethosu/regor/compiler/graph_optimiser.hpp index 0d844cf7b42032ac739efaf635d69d5478fe5294..f4c799c553ed012adad3293cdad19016213ea495 100644 --- a/ethosu/regor/compiler/graph_optimiser.hpp +++ b/ethosu/regor/compiler/graph_optimiser.hpp @@ -110,7 +110,7 @@ public: // Traverse from End and collect operators that are at the start of the graph. Their inputs are only either // - Constant // - Graph inputs - Graph::TraverseGraphFromEnd(graph->Outputs(), + Graph::TraverseGraphFromEnd(graph->Outputs(), !graph->Persistent().empty(), [&](Operation *op) -> bool { for ( auto [usage, ifmConn] : op->Inputs().pairs() ) diff --git a/ethosu/regor/compiler/scheduler_packing.cpp b/ethosu/regor/compiler/scheduler_packing.cpp index d625771994bd00d439c82e8ea7453c8ad1ff9d8c..69740d3b1580f827aeae38ce23e4bd19d5fe201e 100644 --- a/ethosu/regor/compiler/scheduler_packing.cpp +++ b/ethosu/regor/compiler/scheduler_packing.cpp @@ -113,7 +113,7 @@ std::vector> SchedulerPacking::Process(const { // Get operation list in execution order std::vector executionList; - Graph::TraverseGraphFromEnd(graph->Outputs(), + Graph::TraverseGraphFromEnd(graph->Outputs(), !graph->Persistent().empty(), [&](Operation *op) -> bool { executionList.push_back(op); @@ -481,11 +481,6 @@ int SchedulerPacking::CanPack(const SchedulerOperation *schedOp, const Scheduler return 0; } - if ( schedOp->Type() == OpType::FullyConnected ) - { - return 0; - } - // Do not pack persistent tensors with non persistent tensors if ( prevOFM->isPersistent != nextOp->OFM()->tensor->isPersistent ) { diff --git a/ethosu/regor/compiler/tosa_graph_validator.cpp b/ethosu/regor/compiler/tosa_graph_validator.cpp index ec6c899023badbc1a6f4094728bb1f3bfe936583..65d0846a6ecd35b621188bff0c8cfb6f259b9ae4 100644 --- a/ethosu/regor/compiler/tosa_graph_validator.cpp +++ b/ethosu/regor/compiler/tosa_graph_validator.cpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2023-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -72,7 +72,7 @@ TosaGraphValidator::TosaGraphValidator(GraphNotation notation, uint32_t syntaxVe bool TosaGraphValidator::Validate(Graph *graph) { bool graphValid = true; - Graph::TraverseGraphFromEnd(graph->Outputs(), + Graph::TraverseGraphFromEnd(graph->Outputs(), !graph->Persistent().empty(), [&graphValid, &graph, this](Operation *op) -> bool { try