From 319b4610858418be8cb4865513e6c17b88d6aab7 Mon Sep 17 00:00:00 2001 From: Jacob Bohlin Date: Mon, 12 May 2025 11:52:52 +0100 Subject: [PATCH] MLBEDSW-10776 Revert graph traversal change The graph traversal was modified in MLBEDSW-8926 to traverse tensor writers left-to-right instead of right-to-left. This is required to ensure correct execution order for LSTM. This change reverts the graph traversal back to right-to-left in the general case and left-to-right will only be used on graphs which contain persistent tensors, in order to target LSTM operators. Change-Id: Ibe13af0cf952450cff253ff2e44ee8b96068583a Signed-off-by: Jacob Bohlin --- ethosu/regor/compiler/graph.hpp | 58 ++++++++++++++----- ethosu/regor/compiler/graph_builder.cpp | 2 +- ethosu/regor/compiler/graph_optimiser.hpp | 2 +- ethosu/regor/compiler/scheduler_packing.cpp | 7 +-- .../regor/compiler/tosa_graph_validator.cpp | 4 +- 5 files changed, 49 insertions(+), 24 deletions(-) diff --git a/ethosu/regor/compiler/graph.hpp b/ethosu/regor/compiler/graph.hpp index 90b3d72b..6c8c3197 100644 --- a/ethosu/regor/compiler/graph.hpp +++ b/ethosu/regor/compiler/graph.hpp @@ -128,7 +128,7 @@ public: // Finds all operations which precede a graph output and adds them to the vector in execution order void GetAllOperations(std::vector &operations) const { - TraverseGraphFromEnd(Outputs(), + TraverseGraphFromEnd(Outputs(), !Persistent().empty(), [&](Operation *op) -> bool { operations.push_back(op); @@ -138,7 +138,7 @@ public: void GetAllOperations(std::vector> &operations) const { - TraverseGraphFromEnd(Outputs(), + TraverseGraphFromEnd(Outputs(), !Persistent().empty(), [&](Operation *op) -> bool { operations.push_back(op->shared_from_this()); @@ -151,16 +151,21 @@ public: void SetScheduledOrder(std::vector operations) { _opsInScheduledOrder = std::move(operations); } - // Traverse the graph in right-to-left reverse post-order but processing tensor writers left-to-right. - // This means in below graph, where A and B both write to the input tensor of C, A will be processed - // before B. + // Traverse the graph in right-to-left reverse post-order. + // + // TODO MLBEDSW-10790: Remove special handling of graphs with persistent tensors + // Special care is required if the graph contains any persistent tensors, as they may be written/read + // multiple times. To ensure a functional execution order in this case, the traversal will process tensor + // writers left-to-right. This preserves the order in whcih the operations were added to the graph. + // + // For example, in the graph below, where operations A and B both write to the input tensor of operation C. + // A will be processed before B (opposite of the default traversal). // A B // \ / // | // C - // The rationale is to preserve the order that partial writes are added to the graph. template - static void TraverseGraphFromEnd(const std::vector> &from, OPFUNC opFunc) + static void TraverseGraphFromEnd(const std::vector> &from, bool traverseLeftToRight, OPFUNC opFunc) { struct Entry { @@ -174,10 +179,21 @@ public: for ( const auto &tensor : from ) { - const auto &writers = tensor->Writers(); - for ( auto it = writers.crbegin(); it != writers.crend(); it++ ) + // TODO MLBEDSW-10790: Remove special handling of graphs with persistent tensors + if ( traverseLeftToRight ) { - stack.emplace(false, *it); + const auto &writers = tensor->Writers(); + for ( auto it = writers.crbegin(); it != writers.crend(); it++ ) + { + stack.emplace(false, *it); + } + } + else + { + for ( const auto &op : tensor->Writers() ) + { + stack.emplace(false, op); + } } } @@ -198,12 +214,26 @@ public: stack.emplace(true, entry.op); for ( const auto &pair : entry.op->Inputs().pairs() ) { - const auto &writers = pair.second.tensor->Writers(); - for ( auto it = writers.crbegin(); it != writers.crend(); it++ ) + // TODO MLBEDSW-10790: Remove special handling of graphs with persistent tensors + if ( traverseLeftToRight ) + { + const auto &writers = pair.second.tensor->Writers(); + for ( auto it = writers.crbegin(); it != writers.crend(); it++ ) + { + if ( visited.count(it->get()) == 0 ) + { + stack.emplace(false, *it); + } + } + } + else { - if ( visited.count(it->get()) == 0 ) + for ( const auto &op : pair.second.tensor->Writers() ) { - stack.emplace(false, *it); + if ( visited.count(op.get()) == 0 ) + { + stack.emplace(false, op); + } } } } diff --git a/ethosu/regor/compiler/graph_builder.cpp b/ethosu/regor/compiler/graph_builder.cpp index 3dd3ec06..7a11736a 100644 --- a/ethosu/regor/compiler/graph_builder.cpp +++ b/ethosu/regor/compiler/graph_builder.cpp @@ -548,7 +548,7 @@ void GraphBuilder::FreeUnconnected() { // In case somebody added self-supporting graph fragments std::unordered_set connected; - Graph::TraverseGraphFromEnd(_outputs, + Graph::TraverseGraphFromEnd(_outputs, !_persistent.empty(), [&](Operation *op) -> bool { connected.insert(op); diff --git a/ethosu/regor/compiler/graph_optimiser.hpp b/ethosu/regor/compiler/graph_optimiser.hpp index 0d844cf7..f4c799c5 100644 --- a/ethosu/regor/compiler/graph_optimiser.hpp +++ b/ethosu/regor/compiler/graph_optimiser.hpp @@ -110,7 +110,7 @@ public: // Traverse from End and collect operators that are at the start of the graph. Their inputs are only either // - Constant // - Graph inputs - Graph::TraverseGraphFromEnd(graph->Outputs(), + Graph::TraverseGraphFromEnd(graph->Outputs(), !graph->Persistent().empty(), [&](Operation *op) -> bool { for ( auto [usage, ifmConn] : op->Inputs().pairs() ) diff --git a/ethosu/regor/compiler/scheduler_packing.cpp b/ethosu/regor/compiler/scheduler_packing.cpp index d6257719..69740d3b 100644 --- a/ethosu/regor/compiler/scheduler_packing.cpp +++ b/ethosu/regor/compiler/scheduler_packing.cpp @@ -113,7 +113,7 @@ std::vector> SchedulerPacking::Process(const { // Get operation list in execution order std::vector executionList; - Graph::TraverseGraphFromEnd(graph->Outputs(), + Graph::TraverseGraphFromEnd(graph->Outputs(), !graph->Persistent().empty(), [&](Operation *op) -> bool { executionList.push_back(op); @@ -481,11 +481,6 @@ int SchedulerPacking::CanPack(const SchedulerOperation *schedOp, const Scheduler return 0; } - if ( schedOp->Type() == OpType::FullyConnected ) - { - return 0; - } - // Do not pack persistent tensors with non persistent tensors if ( prevOFM->isPersistent != nextOp->OFM()->tensor->isPersistent ) { diff --git a/ethosu/regor/compiler/tosa_graph_validator.cpp b/ethosu/regor/compiler/tosa_graph_validator.cpp index ec6c8990..65d0846a 100644 --- a/ethosu/regor/compiler/tosa_graph_validator.cpp +++ b/ethosu/regor/compiler/tosa_graph_validator.cpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2023-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -72,7 +72,7 @@ TosaGraphValidator::TosaGraphValidator(GraphNotation notation, uint32_t syntaxVe bool TosaGraphValidator::Validate(Graph *graph) { bool graphValid = true; - Graph::TraverseGraphFromEnd(graph->Outputs(), + Graph::TraverseGraphFromEnd(graph->Outputs(), !graph->Persistent().empty(), [&graphValid, &graph, this](Operation *op) -> bool { try -- GitLab