diff --git a/ethosu/regor/test/test_passthrough.cpp b/ethosu/regor/test/test_passthrough.cpp index 10c13f6f66c07c567f0fd14cfa9a7f72aa68f5ac..a6126d8bbdec7680aaa95ad9d9182910614e2e45 100644 --- a/ethosu/regor/test/test_passthrough.cpp +++ b/ethosu/regor/test/test_passthrough.cpp @@ -533,14 +533,11 @@ TEST_CASE("passthrough") tensors.push_back(tflite::CreateTensorDirect(fbb, &shape, type, bufferIndex, name.c_str())); } - /* - // TODO: MLBEDSW-9946: Metadata passthrough is currently not supported { std::vector data2 = random_vector(5, 0, 255); - serialised_buffers.push_back(tflite::CreateBufferDirect(fbb, &data2)); - serialised_metadata.push_back(tflite::CreateMetadataDirect(fbb, "metadata1", serialised_buffers.size() - 1)); + buffers.push_back(tflite::CreateBufferDirect(fbb, &data2)); + metadata.push_back(tflite::CreateMetadataDirect(fbb, "metadata1", uint32_t(buffers.size() - 1))); } - */ { // Generate 1 operator diff --git a/ethosu/regor/tflite/tflite_writer.cpp b/ethosu/regor/tflite/tflite_writer.cpp index a0e58769f4e41f137bc95e00f032ee6b680dfad5..8db24ba2b972336064aa3274bf707f3c96d08451 100644 --- a/ethosu/regor/tflite/tflite_writer.cpp +++ b/ethosu/regor/tflite/tflite_writer.cpp @@ -87,7 +87,7 @@ std::unique_ptr TfLiteWriter::SerialiseImpl(const std::vector> serialised_metadata; // TODO: passthrough metadata + std::vector> serialised_metadata; for ( const auto &graph : graphs ) { @@ -256,7 +256,38 @@ std::unique_ptr TfLiteWriter::SerialiseImpl(const std::vector 0 ) + { + const auto tflite_model = static_cast(graphs[0]->Passthrough()); + if ( tflite_model ) + { + const auto *tflite_metadata = tflite_model->metadata(); + const auto *tflite_buffers = tflite_model->buffers(); + if ( tflite_metadata && tflite_buffers ) + { + for ( auto it = tflite_metadata->begin(); it != tflite_metadata->end(); it++ ) + { + const auto buffer = (*it)->buffer(); + if ( buffer >= tflite_buffers->size() ) continue; // Invalid buffer + const auto name = (*it)->name(); + if ( !name ) continue; // Invalid name + const auto data = FlatbufferUtils::CopyVector(_flatbuffer, tflite_buffers->Get(buffer)->data()); + const auto offset = tflite_buffers->Get(buffer)->offset(); + const auto size = tflite_buffers->Get(buffer)->size(); + // Copy buffer + _serialised_buffers.push_back(tflite::CreateBuffer(_flatbuffer, data, offset, size)); + // Copy metadata + serialised_metadata.push_back(tflite::CreateMetadata( + _flatbuffer, _flatbuffer.CreateString(name), uint32_t(_serialised_buffers.size() - 1))); + // If we copied a OfflineMemoryAllocation, don't create a new one later on + if ( name->str() == "OfflineMemoryAllocation" ) hasOfflineMemoryAllocation = true; + } + } + } + } + + if ( !_skipOfflineMemoryAllocation && !hasOfflineMemoryAllocation ) { serialised_metadata.push_back(SerialiseTensorAddresses(int(_serialised_subgraphs.size()))); }