From 9eacbf17c93151a3317077ccdfd19576b41c5074 Mon Sep 17 00:00:00 2001 From: Johan Gunnarsson Date: Sun, 5 Jan 2025 12:39:28 +0100 Subject: [PATCH] MLBEDSW-10213: Handle networks without graph inputs * Don't remove TOSA IDENTITY ops with constant IFM. * Don't expect TOSA graph inputs. Signed-off-by: Johan Gunnarsson Change-Id: I779d03e378741d3617c9794020185d38508d1803 --- ethosu/regor/compiler/graphir_optimiser.cpp | 5 +++-- ethosu/regor/tosa/tosa_reader.cpp | 9 ++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp index d24def0e..7928894b 100644 --- a/ethosu/regor/compiler/graphir_optimiser.cpp +++ b/ethosu/regor/compiler/graphir_optimiser.cpp @@ -383,14 +383,15 @@ Operation *GraphIrOptimiser::RemoveReshape(Graph *const graph, Operation *const auto *ifm = ifmConn->tensor.get(); auto *ofm = ofmConn->tensor.get(); - // Check if ifm/ofm are network ifm/ofm + // Check if ifm/ofm are network ifm/ofm or constant + bool isIfmConst = ifm->IsConstant(); bool isIfmSgIfm = IsTensorInVector(graph->Inputs(), ifm); bool isOfmSgOfm = IsTensorInVector(graph->Outputs(), ofm); bool isIfmSgOfm = IsTensorInVector(graph->Outputs(), ifm); // TODO: MLBEDSW-9069: Check CPU operator producer/consumer // Inserts a copy op if needed before removing reshapes. - if ( ((isIfmSgIfm || isIfmSgOfm) && (isOfmSgOfm)) || + if ( ((isIfmSgIfm || isIfmSgOfm || isIfmConst) && (isOfmSgOfm)) || ((ifm->Readers().size() > 1) && (ifm->StorageShape() != ofm->StorageShape() || ifm->AxisOrder() != ofm->AxisOrder())) ) { auto copyOp = InsertCopyOpAfterTensor(ifmConn->tensor, ifmConn->quantization); diff --git a/ethosu/regor/tosa/tosa_reader.cpp b/ethosu/regor/tosa/tosa_reader.cpp index b1aab0ba..582bf4db 100644 --- a/ethosu/regor/tosa/tosa_reader.cpp +++ b/ethosu/regor/tosa/tosa_reader.cpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -977,9 +977,12 @@ void TosaReader::LoadGraphs(const tosaFb::TosaGraph *model, std::listinputs()) ) + if ( tosa_basicblock->inputs() ) { - builder->AddInput(tensors.at(ten->str())); + for ( auto ten : SafeDeref(tosa_basicblock->inputs()) ) + { + builder->AddInput(tensors.at(ten->str())); + } } for ( auto ten : SafeDeref(tosa_basicblock->outputs()) ) { -- GitLab