From 9e8131008319fccf4a7e1819d7b3e1e37fb27e10 Mon Sep 17 00:00:00 2001 From: Johan Gunnarsson Date: Thu, 12 Dec 2024 14:41:25 +0100 Subject: [PATCH] MLBEDSW-10167: Write out all input tensors to passthrough ops We don't have tensor mappings for passthrough ops, so we should write out all input tensors in the order we read them in. Before this patch we only wrote out the IFM tensors. Not the Params tensors. Signed-off-by: Johan Gunnarsson Change-Id: Ibb9b8c88a96e9cc9267769af54c6b09faceaa223 --- ethosu/regor/tflite/tflite_writer.cpp | 30 +++++++++++++++++++-------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/ethosu/regor/tflite/tflite_writer.cpp b/ethosu/regor/tflite/tflite_writer.cpp index 79005e23..db70f242 100644 --- a/ethosu/regor/tflite/tflite_writer.cpp +++ b/ethosu/regor/tflite/tflite_writer.cpp @@ -239,18 +239,30 @@ std::vector TfLiteWriter::SortedInputTensors(const Operation *op { std::vector tensors; - int ifm = 0; - for ( const auto &pair : TfLiteMapping::InputTensorIndices(type) ) + const auto tensorIndices = TfLiteMapping::InputTensorIndices(type); + if ( tensorIndices.begin() != tensorIndices.end() ) { - const TensorUsage usage = pair.second; - const auto conn = operation->Input(usage); - tensors.push_back(conn ? conn->tensor.get() : nullptr); - ifm += IsIFM(usage); + // If we have tensor indices for this op type, use that tensor order + int ifm = 0; + for ( const auto &[type_, usage] : tensorIndices ) + { + const auto conn = operation->Input(usage); + tensors.push_back(conn ? conn->tensor.get() : nullptr); + ifm += IsIFM(usage); + } + while ( operation->Input(MakeTensorUsage(TensorUsage::IFM, ifm)) ) + { + tensors.push_back(operation->IFM(ifm)); + ifm++; + } } - while ( operation->Input(MakeTensorUsage(TensorUsage::IFM, ifm)) ) + else { - tensors.push_back(operation->IFM(ifm)); - ifm++; + // If we don't have tensor indices for this op type, use the tensor order we have + for ( const auto &[usage, conn] : operation->Inputs().pairs() ) + { + tensors.push_back(conn.tensor.get()); + } } return tensors; } -- GitLab