From 26b2f3b23a249b844340f7da52a67b46db3bee42 Mon Sep 17 00:00:00 2001 From: Johan Gunnarsson Date: Wed, 21 May 2025 13:07:46 +0200 Subject: [PATCH] MLBEDSW-9946: Add passthrough for TFLite metadata Signed-off-by: Johan Gunnarsson Change-Id: Iab867cee293980582072e9cb8eb9c9cd38ad8169 --- ethosu/regor/test/test_passthrough.cpp | 7 ++---- ethosu/regor/tflite/tflite_writer.cpp | 35 ++++++++++++++++++++++++-- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/ethosu/regor/test/test_passthrough.cpp b/ethosu/regor/test/test_passthrough.cpp index 10c13f6f..a6126d8b 100644 --- a/ethosu/regor/test/test_passthrough.cpp +++ b/ethosu/regor/test/test_passthrough.cpp @@ -533,14 +533,11 @@ TEST_CASE("passthrough") tensors.push_back(tflite::CreateTensorDirect(fbb, &shape, type, bufferIndex, name.c_str())); } - /* - // TODO: MLBEDSW-9946: Metadata passthrough is currently not supported { std::vector data2 = random_vector(5, 0, 255); - serialised_buffers.push_back(tflite::CreateBufferDirect(fbb, &data2)); - serialised_metadata.push_back(tflite::CreateMetadataDirect(fbb, "metadata1", serialised_buffers.size() - 1)); + buffers.push_back(tflite::CreateBufferDirect(fbb, &data2)); + metadata.push_back(tflite::CreateMetadataDirect(fbb, "metadata1", uint32_t(buffers.size() - 1))); } - */ { // Generate 1 operator diff --git a/ethosu/regor/tflite/tflite_writer.cpp b/ethosu/regor/tflite/tflite_writer.cpp index a0e58769..8db24ba2 100644 --- a/ethosu/regor/tflite/tflite_writer.cpp +++ b/ethosu/regor/tflite/tflite_writer.cpp @@ -87,7 +87,7 @@ std::unique_ptr TfLiteWriter::SerialiseImpl(const std::vector> serialised_metadata; // TODO: passthrough metadata + std::vector> serialised_metadata; for ( const auto &graph : graphs ) { @@ -256,7 +256,38 @@ std::unique_ptr TfLiteWriter::SerialiseImpl(const std::vector 0 ) + { + const auto tflite_model = static_cast(graphs[0]->Passthrough()); + if ( tflite_model ) + { + const auto *tflite_metadata = tflite_model->metadata(); + const auto *tflite_buffers = tflite_model->buffers(); + if ( tflite_metadata && tflite_buffers ) + { + for ( auto it = tflite_metadata->begin(); it != tflite_metadata->end(); it++ ) + { + const auto buffer = (*it)->buffer(); + if ( buffer >= tflite_buffers->size() ) continue; // Invalid buffer + const auto name = (*it)->name(); + if ( !name ) continue; // Invalid name + const auto data = FlatbufferUtils::CopyVector(_flatbuffer, tflite_buffers->Get(buffer)->data()); + const auto offset = tflite_buffers->Get(buffer)->offset(); + const auto size = tflite_buffers->Get(buffer)->size(); + // Copy buffer + _serialised_buffers.push_back(tflite::CreateBuffer(_flatbuffer, data, offset, size)); + // Copy metadata + serialised_metadata.push_back(tflite::CreateMetadata( + _flatbuffer, _flatbuffer.CreateString(name), uint32_t(_serialised_buffers.size() - 1))); + // If we copied a OfflineMemoryAllocation, don't create a new one later on + if ( name->str() == "OfflineMemoryAllocation" ) hasOfflineMemoryAllocation = true; + } + } + } + } + + if ( !_skipOfflineMemoryAllocation && !hasOfflineMemoryAllocation ) { serialised_metadata.push_back(SerialiseTensorAddresses(int(_serialised_subgraphs.size()))); } -- GitLab