diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp index d24def0e1260b4890c2b6b773eb497e62d00c4c5..7928894b64629a4060b547e586fe8bb54e2a5fc1 100644 --- a/ethosu/regor/compiler/graphir_optimiser.cpp +++ b/ethosu/regor/compiler/graphir_optimiser.cpp @@ -383,14 +383,15 @@ Operation *GraphIrOptimiser::RemoveReshape(Graph *const graph, Operation *const auto *ifm = ifmConn->tensor.get(); auto *ofm = ofmConn->tensor.get(); - // Check if ifm/ofm are network ifm/ofm + // Check if ifm/ofm are network ifm/ofm or constant + bool isIfmConst = ifm->IsConstant(); bool isIfmSgIfm = IsTensorInVector(graph->Inputs(), ifm); bool isOfmSgOfm = IsTensorInVector(graph->Outputs(), ofm); bool isIfmSgOfm = IsTensorInVector(graph->Outputs(), ifm); // TODO: MLBEDSW-9069: Check CPU operator producer/consumer // Inserts a copy op if needed before removing reshapes. - if ( ((isIfmSgIfm || isIfmSgOfm) && (isOfmSgOfm)) || + if ( ((isIfmSgIfm || isIfmSgOfm || isIfmConst) && (isOfmSgOfm)) || ((ifm->Readers().size() > 1) && (ifm->StorageShape() != ofm->StorageShape() || ifm->AxisOrder() != ofm->AxisOrder())) ) { auto copyOp = InsertCopyOpAfterTensor(ifmConn->tensor, ifmConn->quantization); diff --git a/ethosu/regor/tosa/tosa_reader.cpp b/ethosu/regor/tosa/tosa_reader.cpp index b1aab0baf9c543a6bcb3eff83dfe3aff511d4a2e..582bf4dbb4f54917de3a54bfcfe82aee83f2b7f9 100644 --- a/ethosu/regor/tosa/tosa_reader.cpp +++ b/ethosu/regor/tosa/tosa_reader.cpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -977,9 +977,12 @@ void TosaReader::LoadGraphs(const tosaFb::TosaGraph *model, std::listinputs()) ) + if ( tosa_basicblock->inputs() ) { - builder->AddInput(tensors.at(ten->str())); + for ( auto ten : SafeDeref(tosa_basicblock->inputs()) ) + { + builder->AddInput(tensors.at(ten->str())); + } } for ( auto ten : SafeDeref(tosa_basicblock->outputs()) ) {