From 88d7a53f1401f054cd4c4463e0927c9c9979b7d7 Mon Sep 17 00:00:00 2001 From: Johan Gunnarsson Date: Thu, 10 Apr 2025 15:12:29 +0200 Subject: [PATCH 1/2] MLBEDSW-10656: Change tflite_writer to copy tables using reflection This changes tflite_writer so that it copies passthrough tables using flatbuffer minireflection. Tables affected by this change are: * tensor.quantization * tensor.sparsity * tensor.variant_tensors * operator.builtin_options * operator.builtin_options_2 Signed-off-by: Johan Gunnarsson Change-Id: Ia197694a09d5eb97480f303919afd209e8221332 --- ethosu/regor/tflite/flatbuffer_utils.hpp | 291 ++++++++++++- ethosu/regor/tflite/tflite_writer.cpp | 494 ++--------------------- 2 files changed, 330 insertions(+), 455 deletions(-) diff --git a/ethosu/regor/tflite/flatbuffer_utils.hpp b/ethosu/regor/tflite/flatbuffer_utils.hpp index 5b29809c..796e03cc 100644 --- a/ethosu/regor/tflite/flatbuffer_utils.hpp +++ b/ethosu/regor/tflite/flatbuffer_utils.hpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2021, 2023, 2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -18,13 +18,19 @@ #pragma once +#include "common/logging.hpp" + #include namespace FlatbufferUtils { + +static flatbuffers::Offset<> CopyTable(flatbuffers::FlatBufferBuilder &destination, const flatbuffers::Table *source, + const flatbuffers::TypeTable *typeTable); + // Load a vector (if present) from a flatbuffer into a local copy. // Intended for small vectors only - large vectors should be left in place and mapped using a Buffer class instead. -template +template, int> = 0> static std::vector LoadVector(const flatbuffers::Vector *source) { std::vector destination; @@ -35,8 +41,8 @@ static std::vector LoadVector(const flatbuffers::Vector *source) return destination; } -// Copy a vector (if present) from one flatbuffer to another, returning the offset into the destination buffer.
-template +// Copy a vector of scalars (if present) from one flatbuffer to another +template, int> = 0> static flatbuffers::Offset> CopyVector(flatbuffers::FlatBufferBuilder &destination, const flatbuffers::Vector *source) { @@ -44,6 +50,283 @@ CopyVector(flatbuffers::FlatBufferBuilder &destination, const flatbuffers::Vecto { return destination.CreateVector(source->data(), source->size()); } + + // Zero offset means unset or non-existing vector + return 0; +} + +// Make a copy of a repeating table field, by field +static flatbuffers::Offset>> CopyVectorOfTables(flatbuffers::FlatBufferBuilder &destination, + const flatbuffers::Table *source, flatbuffers::voffset_t field, const flatbuffers::TypeTable *types) +{ + std::vector> dstVector; + if ( source ) + { + const auto *srcVector = source->GetPointer> *>(field); + if ( srcVector ) + { + for ( const auto *table : *srcVector ) + { + dstVector.push_back(CopyTable(destination, table, types)); + } + } + } + + // Create a new vector + return destination.CreateVector(dstVector); +} + +// Make a copy of a repeating string field, by field +static flatbuffers::Offset>> CopyVectorOfStrings( + flatbuffers::FlatBufferBuilder &destination, const flatbuffers::Table *source, flatbuffers::voffset_t field) +{ + std::vector> dstVector; + if ( source ) + { + const auto srcVector = source->GetPointer> *>(field); + if ( srcVector ) + { + for ( const auto *str : *srcVector ) + { + dstVector.push_back(destination.CreateString(str)); + } + } + } + + // Create a new vector + return destination.CreateVector(dstVector); +} + +// Copy a vector of scalars (if present) from one flatbuffer to another, by field +template, int> = 0> +static flatbuffers::Offset> +CopyVector(flatbuffers::FlatBufferBuilder &destination, const flatbuffers::Table *source, flatbuffers::voffset_t field) +{ + if ( source && source->CheckField(field) ) + { + return CopyVector(destination, source->GetPointer *>(field)); + } + + // Zero offset means unset or non-existing vector + return 0; +} + +// Copy a vector of scalars (if present) from one flatbuffer to another, by field +static flatbuffers::Offset<> CopyVectorOfScalars(flatbuffers::FlatBufferBuilder &destination, + const flatbuffers::Table *source, flatbuffers::voffset_t field, flatbuffers::ElementaryType type) +{ + if ( source && source->CheckField(field) ) + { + if ( type == flatbuffers::ET_BOOL ) return CopyVector(destination, source, field).Union(); + else if ( type == flatbuffers::ET_CHAR ) return CopyVector(destination, source, field).Union(); + else if ( type == flatbuffers::ET_UCHAR ) return CopyVector(destination, source, field).Union(); + else if ( type == flatbuffers::ET_SHORT ) return CopyVector(destination, source, field).Union(); + else if ( type == flatbuffers::ET_USHORT ) return CopyVector(destination, source, field).Union(); + else if ( type == flatbuffers::ET_INT ) return CopyVector(destination, source, field).Union(); + else if ( type == flatbuffers::ET_UINT ) return CopyVector(destination, source, field).Union(); + else if ( type == flatbuffers::ET_LONG ) return CopyVector(destination, source, field).Union(); + else if ( type == flatbuffers::ET_ULONG ) return CopyVector(destination, source, field).Union(); + else if ( type == flatbuffers::ET_FLOAT ) return CopyVector(destination, source, field).Union(); + else if ( type == flatbuffers::ET_DOUBLE ) return CopyVector(destination, source, field).Union(); + else if ( type == flatbuffers::ET_STRING ) return CopyVectorOfStrings(destination, source, field).Union(); + else 
assert(false && "Unsupported elementary type"); + } + + // Zero offset means unset or non-existing vector + return 0; +} + +// Copy a scalar (if present) from one flatbuffer to another, by field +template, int> = 0> +static void CopyScalar(flatbuffers::FlatBufferBuilder &destination, const flatbuffers::Table *source, flatbuffers::voffset_t field) +{ + if ( source && source->CheckField(field) ) + { + destination.AddElement(field, source->GetField(field, 0)); + } +} + +// Copy a string (if present) from one flatbuffer to another, by field +static flatbuffers::Offset +CopyString(flatbuffers::FlatBufferBuilder &destination, const flatbuffers::Table *source, flatbuffers::voffset_t field) +{ + if ( source && source->CheckField(field) ) + { + return destination.CreateString(source->GetPointer(field)); + } + + // Zero offset means unset or non-existing string + return 0; +} + +// Copy a table (if present) from one flatbuffer to another by field +static flatbuffers::Offset<> CopyTable(flatbuffers::FlatBufferBuilder &destination, const flatbuffers::Table *source, + flatbuffers::voffset_t field, const flatbuffers::TypeTable *types) +{ + if ( source && source->CheckField(field) ) + { + return CopyTable(destination, source->GetPointer(field), types); + } + + // Zero offset means unset or non-existing table return 0; } + +// Copy a table +static flatbuffers::Offset<> CopyTable(flatbuffers::FlatBufferBuilder &destination, const flatbuffers::Table *source, + const flatbuffers::TypeTable *typeTable) +{ + // This function copies a flatbuffer Table by iterating over the type table two times. First time to copy + // offsets to all non-scalar, repeating types and strings, then a second time to add all offsets and scalars. + + if ( !source ) return flatbuffers::Offset<>(); + + // Can only copy tables + assert(typeTable); + assert(typeTable->st == flatbuffers::ST_TABLE); + + std::unordered_map> fieldToOffset; + + // Iterate over all types and create offsets for non-scalar items + for ( flatbuffers::voffset_t i = 0; i < typeTable->num_elems; i++ ) + { + const auto name = typeTable->names[i]; + const auto type = flatbuffers::ElementaryType(typeTable->type_codes[i].base_type); + const auto isRepeating = typeTable->type_codes[i].is_repeating != 0; + const auto sequenceRef = typeTable->type_codes[i].sequence_ref; + const auto field = flatbuffers::FieldIndexToOffset(i); + + if ( isRepeating ) + { + if ( sequenceRef >= 0 ) + { + LOG_TRACE1("Copy repeating {} (vector of {})\n", name, flatbuffers::ElementaryTypeNames()[type]); + + const auto *sequenceTypeTable = typeTable->type_refs[sequenceRef](); + assert(sequenceTypeTable); + + if ( type == flatbuffers::ET_SEQUENCE ) + fieldToOffset[field] = CopyVectorOfTables(destination, source, field, sequenceTypeTable).Union(); + else fieldToOffset[field] = CopyVectorOfScalars(destination, source, field, type); + } + else + { + LOG_TRACE1("Copy repeating {} (vector of {})\n", name, flatbuffers::ElementaryTypeNames()[type]); + + // Vector of scalars/strings + fieldToOffset[field] = CopyVectorOfScalars(destination, source, field, type); + } + } + else if ( sequenceRef >= 0 ) + { + const auto *sequenceTypeTable = typeTable->type_refs[sequenceRef](); + assert(sequenceTypeTable); + + if ( sequenceTypeTable->st == flatbuffers::ST_TABLE ) + { + LOG_TRACE1("Copy non-repeating table {} ({})\n", name, flatbuffers::ElementaryTypeNames()[type]); + + fieldToOffset[field] = CopyTable(destination, source, field, sequenceTypeTable); + } + else if ( sequenceTypeTable->st == flatbuffers::ST_UNION ) 
+ { + LOG_TRACE1("Copy non-repeating union {} ({})\n", name, flatbuffers::ElementaryTypeNames()[type]); + + if ( type == flatbuffers::ET_SEQUENCE ) + { + const auto unionType = source->GetField(flatbuffers::FieldIndexToOffset(i - 1u), 0); + if ( unionType > 0 ) + fieldToOffset[field] = CopyTable(destination, source, field, sequenceTypeTable->type_refs[unionType - 1]()); + else fieldToOffset[field] = flatbuffers::Offset<>(); // Default NONE value + } + } + } + else if ( type == flatbuffers::ElementaryType::ET_STRING ) + { + LOG_TRACE1("Copy non-repeating non-sequence {} ({})\n", name, flatbuffers::ElementaryTypeNames()[type]); + + fieldToOffset[field] = CopyString(destination, source, field).Union(); + } + } + + const auto tableOffset = destination.StartTable(); + + // Iterate over all types and add offsets and scalar types + for ( flatbuffers::voffset_t i = 0; i < typeTable->num_elems; i++ ) + { + const auto name = typeTable->names[i]; + const auto baseType = flatbuffers::ElementaryType(typeTable->type_codes[i].base_type); + const auto isRepeating = typeTable->type_codes[i].is_repeating != 0; + const auto sequenceRef = typeTable->type_codes[i].sequence_ref; + const auto field = flatbuffers::FieldIndexToOffset(i); + + if ( fieldToOffset.count(field) == 0 ) + { + LOG_TRACE1("Copying and adding scalar {} ({})\n", name, flatbuffers::ElementaryTypeNames()[baseType]); + + // At this point it's too late for repeating types + assert(!isRepeating); + + if ( sequenceRef >= 0 ) + { + // At this point it's too late for any sequence types except for ENUM and the UNION type + const auto *sequenceTypeTable = typeTable->type_refs[sequenceRef](); + assert(sequenceTypeTable); + assert(sequenceTypeTable->st == flatbuffers::ST_ENUM || + (sequenceTypeTable->st == flatbuffers::ST_UNION && baseType == flatbuffers::ET_UTYPE)); + } + + // Scalar + if ( baseType == flatbuffers::ET_UTYPE ) CopyScalar(destination, source, field); + else if ( baseType == flatbuffers::ET_BOOL ) CopyScalar(destination, source, field); + else if ( baseType == flatbuffers::ET_CHAR ) CopyScalar(destination, source, field); + else if ( baseType == flatbuffers::ET_UCHAR ) CopyScalar(destination, source, field); + else if ( baseType == flatbuffers::ET_SHORT ) CopyScalar(destination, source, field); + else if ( baseType == flatbuffers::ET_USHORT ) CopyScalar(destination, source, field); + else if ( baseType == flatbuffers::ET_INT ) CopyScalar(destination, source, field); + else if ( baseType == flatbuffers::ET_UINT ) CopyScalar(destination, source, field); + else if ( baseType == flatbuffers::ET_LONG ) CopyScalar(destination, source, field); + else if ( baseType == flatbuffers::ET_ULONG ) CopyScalar(destination, source, field); + else if ( baseType == flatbuffers::ET_FLOAT ) CopyScalar(destination, source, field); + else if ( baseType == flatbuffers::ET_DOUBLE ) CopyScalar(destination, source, field); + else assert(false && "Unsupported elementary type"); + } + else + { + LOG_TRACE1("Adding offset {} ({})\n", name, flatbuffers::ElementaryTypeNames()[baseType]); + + destination.AddOffset(field, fieldToOffset[field]); + } + } + + return destination.EndTable(tableOffset); +} + +// Copy a table +template, int> = 0> +static flatbuffers::Offset CopyTable(flatbuffers::FlatBufferBuilder &destination, const T *source) +{ + const flatbuffers::Offset<> offset = CopyTable( + destination, reinterpret_cast(source), T::MiniReflectTypeTable()); + + // Special thing here to create an Offset object with a type + return flatbuffers::Offset(offset.o); +} + +// 
Copy a vector of tables +template, int> = 0> +static std::vector> +CopyVectorOfTables(flatbuffers::FlatBufferBuilder &destination, const flatbuffers::Vector> *source) +{ + std::vector> srcTables; + if ( source ) + { + for ( const auto *table : *source ) + { + srcTables.push_back(CopyTable(destination, table)); + } + } + return srcTables; +} + } // namespace FlatbufferUtils diff --git a/ethosu/regor/tflite/tflite_writer.cpp b/ethosu/regor/tflite/tflite_writer.cpp index b01ec963..18b540cc 100644 --- a/ethosu/regor/tflite/tflite_writer.cpp +++ b/ethosu/regor/tflite/tflite_writer.cpp @@ -29,23 +29,6 @@ #include #include -// Specialization for sparsity dimension metadata -template<> -flatbuffers::Offset>> FlatbufferUtils::CopyVector( - flatbuffers::FlatBufferBuilder &dst, const flatbuffers::Vector> *src) -{ - if ( src ) - { - std::vector> offsets; - for ( const auto &dimension_metadata : *src ) - { - // TODO: offsets.push_back(tflite::CreateDimensionMetadata(...)) - } - return dst.CreateVector>(offsets); - } - return 0; -} - namespace regor { @@ -408,59 +391,39 @@ int TfLiteWriter::SerialisedTensorIndex(const Tensor *tensor, const std::unorder flatbuffers::Offset TfLiteWriter::SerialiseTensor(const Tensor *tensor, const Graph &graph) { auto tflite_shape = tensor->StorageShape().ToList(); - std::vector quant_min; - std::vector quant_max; - std::vector scale_f32; - std::vector zeroPoints; - int dimension = 0; // Unused parameters are set to default or, if present in the input model, passed through unmodified - tflite::QuantizationDetails custom_quantization = tflite::QuantizationDetails::NONE; - flatbuffers::Offset custom_quantization_details = 0; bool is_variable = graph.IsPersistent(tensor); flatbuffers::Offset sparsity = 0; std::vector shape_signature; + bool has_rank = false; + std::vector> variant_tensors; + flatbuffers::Offset quantization = 0; if ( tensor->Passthrough() ) { const auto tflite_tensor = static_cast(tensor->Passthrough()); - const DataType type = TfLiteMapping::TensorTypeToDataType(tflite_tensor->type()); if ( tflite_tensor->quantization() ) { - if ( tflite_tensor->quantization()->scale() && tflite_tensor->quantization()->zero_point() ) - { - quant_min = FlatbufferUtils::LoadVector(tflite_tensor->quantization()->min()); - quant_max = FlatbufferUtils::LoadVector(tflite_tensor->quantization()->max()); - scale_f32 = FlatbufferUtils::LoadVector(tflite_tensor->quantization()->scale()); - zeroPoints = FlatbufferUtils::LoadVector(tflite_tensor->quantization()->zero_point()); - dimension = tflite_tensor->quantization()->quantized_dimension(); - } - - custom_quantization = tflite_tensor->quantization()->details_type(); - if ( custom_quantization == tflite::QuantizationDetails::CustomQuantization ) - { - if ( tflite_tensor->quantization()->details() ) - { - // TODO: custom_quantization_details - } - } + quantization = FlatbufferUtils::CopyTable(_flatbuffer, tflite_tensor->quantization()); } is_variable = tflite_tensor->is_variable(); if ( tflite_tensor->sparsity() ) { - auto traversal_order = FlatbufferUtils::CopyVector(_flatbuffer, tflite_tensor->sparsity()->traversal_order()); - auto block_map = FlatbufferUtils::CopyVector(_flatbuffer, tflite_tensor->sparsity()->block_map()); - auto dim_metadata = FlatbufferUtils::CopyVector>( - _flatbuffer, tflite_tensor->sparsity()->dim_metadata()); + sparsity = FlatbufferUtils::CopyTable(_flatbuffer, tflite_tensor->sparsity()); + } - sparsity = tflite::CreateSparsityParameters(_flatbuffer, traversal_order, block_map, dim_metadata); + if ( 
tflite_tensor->variant_tensors() ) + { + variant_tensors = FlatbufferUtils::CopyVectorOfTables(_flatbuffer, tflite_tensor->variant_tensors()); } shape_signature = FlatbufferUtils::LoadVector(tflite_tensor->shape_signature()); tflite_shape = FlatbufferUtils::LoadVector(tflite_tensor->shape()); + has_rank = tflite_tensor->has_rank(); } int buffer_index = 0; // Default to the empty buffer at index 0 @@ -481,16 +444,9 @@ flatbuffers::Offset TfLiteWriter::SerialiseTensor(const Tensor * } } - flatbuffers::Offset quantization = 0; - if ( !scale_f32.empty() && !zeroPoints.empty() ) - { - quantization = tflite::CreateQuantizationParametersDirect(_flatbuffer, &quant_min, &quant_max, &scale_f32, - &zeroPoints, custom_quantization, custom_quantization_details, dimension); - } - return tflite::CreateTensorDirect(_flatbuffer, tflite_shape.size() ? &tflite_shape : nullptr, TfLiteMapping::DataTypeToTensorType(tensor->Type()), buffer_index, tensor->Name().c_str(), quantization, - is_variable, sparsity, shape_signature.size() ? &shape_signature : nullptr); + is_variable, sparsity, shape_signature.size() ? &shape_signature : nullptr, has_rank, &variant_tensors); } template @@ -501,6 +457,7 @@ static const T *GetBuiltinOptions(const tflite::Operator *tflite_operator) return options; } +// Serialize builtin_options and return offset to it flatbuffers::Offset TfLiteWriter::SerialiseOptions(const Operation *operation, OpType opType) { if ( opType == OpType::CustomNpuOp ) @@ -508,395 +465,25 @@ flatbuffers::Offset TfLiteWriter::SerialiseOptions(const Operation *operat return 0; } - flatbuffers::Offset offset = 0; - const tflite::Operator *const passthrough = static_cast(operation->Passthrough()); - assert(passthrough); - const auto type = passthrough->builtin_options_type(); - - switch ( type ) + const auto tfliteOperator = static_cast(operation->Passthrough()); + assert(tfliteOperator); + const tflite::BuiltinOptions unionMemberType = tfliteOperator->builtin_options_type(); + if ( unionMemberType == tflite::BuiltinOptions::NONE ) { - case tflite::BuiltinOptions::NONE: - break; - - case tflite::BuiltinOptions::Conv2DOptions: - { - assert(passthrough->builtin_options_as_Conv2DOptions()); - tflite::ActivationFunctionType fused_activation_function = passthrough->builtin_options_as_Conv2DOptions()->fused_activation_function(); - const auto kernel = TfLiteKernel(*operation->Kernel()); - const auto typed_offset = tflite::CreateConv2DOptions(_flatbuffer, kernel.padding, kernel.stride_w, - kernel.stride_h, fused_activation_function, kernel.dilation_w_factor, kernel.dilation_h_factor); - offset = typed_offset.Union(); - } - break; - - case tflite::BuiltinOptions::DepthwiseConv2DOptions: - { - assert(passthrough->builtin_options_as_DepthwiseConv2DOptions()); - tflite::ActivationFunctionType fused_activation_function = - passthrough->builtin_options_as_DepthwiseConv2DOptions()->fused_activation_function(); - const auto kernel = TfLiteKernel(*operation->Kernel()); - const auto typed_offset = tflite::CreateDepthwiseConv2DOptions(_flatbuffer, kernel.padding, kernel.stride_w, - kernel.stride_h, kernel.depth_multiplier, fused_activation_function, kernel.dilation_w_factor, kernel.dilation_h_factor); - offset = typed_offset.Union(); - } - break; - - case tflite::BuiltinOptions::TransposeConvOptions: - { - assert(passthrough->builtin_options_as_TransposeConvOptions()); - tflite::ActivationFunctionType fused_activation_function = - passthrough->builtin_options_as_TransposeConvOptions()->fused_activation_function(); - - const auto 
kernel = TfLiteKernel(*operation->Kernel()); - const auto typed_offset = tflite::CreateTransposeConvOptions( - _flatbuffer, kernel.padding, kernel.stride_w, kernel.stride_h, fused_activation_function); - offset = typed_offset.Union(); - } - break; - - case tflite::BuiltinOptions::Pool2DOptions: - { - assert(passthrough->builtin_options_as_Pool2DOptions()); - tflite::ActivationFunctionType fused_activation_function = passthrough->builtin_options_as_Pool2DOptions()->fused_activation_function(); - const auto kernel = TfLiteKernel(*operation->Kernel()); - const auto typed_offset = tflite::CreatePool2DOptions(_flatbuffer, kernel.padding, kernel.stride_w, - kernel.stride_h, kernel.filter_w, kernel.filter_h, fused_activation_function); - offset = typed_offset.Union(); - } - break; - - case tflite::BuiltinOptions::FullyConnectedOptions: - { - assert(passthrough->builtin_options_as_FullyConnectedOptions()); - tflite::ActivationFunctionType fused_activation_function = - passthrough->builtin_options_as_FullyConnectedOptions()->fused_activation_function(); - const auto typed_offset = tflite::CreateFullyConnectedOptions(_flatbuffer, fused_activation_function - // TODO: weights_format, - // TODO: keep_num_dims, - // TODO: asymmetric_quantize_inputs - ); - offset = typed_offset.Union(); - } - break; - - case tflite::BuiltinOptions::SoftmaxOptions: - { - const auto options = GetBuiltinOptions(passthrough); - offset = tflite::CreateSoftmaxOptions(_flatbuffer, options->beta()).Union(); - } - break; - - case tflite::BuiltinOptions::ConcatenationOptions: - { - assert(passthrough->builtin_options_as_ConcatenationOptions()); - tflite::ActivationFunctionType fused_activation_function = - passthrough->builtin_options_as_ConcatenationOptions()->fused_activation_function(); - const auto options = GetBuiltinOptions(passthrough); - const auto typed_offset = tflite::CreateConcatenationOptions(_flatbuffer, options->axis(), fused_activation_function); - offset = typed_offset.Union(); - } - break; - - case tflite::BuiltinOptions::AddOptions: - { - assert(passthrough->builtin_options_as_AddOptions()); - tflite::ActivationFunctionType fused_activation_function = passthrough->builtin_options_as_AddOptions()->fused_activation_function(); - const auto typed_offset = tflite::CreateAddOptions(_flatbuffer, fused_activation_function, - GetBuiltinOptions(passthrough)->pot_scale_int16()); - offset = typed_offset.Union(); - } - break; - - case tflite::BuiltinOptions::SubOptions: - { - assert(passthrough->builtin_options_as_SubOptions()); - tflite::ActivationFunctionType fused_activation_function = passthrough->builtin_options_as_SubOptions()->fused_activation_function(); - const auto typed_offset = tflite::CreateSubOptions(_flatbuffer, fused_activation_function, - GetBuiltinOptions(passthrough)->pot_scale_int16()); - offset = typed_offset.Union(); - } - break; - - case tflite::BuiltinOptions::MulOptions: - { - assert(passthrough->builtin_options_as_MulOptions()); - tflite::ActivationFunctionType fused_activation_function = passthrough->builtin_options_as_MulOptions()->fused_activation_function(); - offset = tflite::CreateMulOptions(_flatbuffer, fused_activation_function).Union(); - } - break; - - case tflite::BuiltinOptions::DivOptions: - { - assert(passthrough->builtin_options_as_DivOptions()); - tflite::ActivationFunctionType fused_activation_function = passthrough->builtin_options_as_DivOptions()->fused_activation_function(); - offset = tflite::CreateDivOptions(_flatbuffer, fused_activation_function).Union(); - } - break; - - 
case tflite::BuiltinOptions::L2NormOptions: - { - assert(passthrough->builtin_options_as_L2NormOptions()); - tflite::ActivationFunctionType fused_activation_function = passthrough->builtin_options_as_L2NormOptions()->fused_activation_function(); - offset = tflite::CreateL2NormOptions(_flatbuffer, fused_activation_function).Union(); - } - break; - - case tflite::BuiltinOptions::ReshapeOptions: - { - // Replicate parameter tensor as ReshapeOptions - const auto tensor = operation->Input(TensorUsage::Params)->tensor; - if ( tensor->Type() != DataType::Int32 ) throw std::runtime_error("unexpected tensor type"); - - BufferReader reader = tensor->View().Values(); - std::vector temp; - std::copy(reader.begin(), reader.end(), std::back_inserter(temp)); - const auto new_shape = _flatbuffer.CreateVector(temp); - offset = tflite::CreateReshapeOptions(_flatbuffer, new_shape).Union(); - } - break; - - case tflite::BuiltinOptions::SqueezeOptions: - { - const auto options = GetBuiltinOptions(passthrough); - const auto typed_offset = tflite::CreateSqueezeOptions( - _flatbuffer, FlatbufferUtils::CopyVector(_flatbuffer, options->squeeze_dims())); - offset = typed_offset.Union(); - } - break; - - case tflite::BuiltinOptions::PackOptions: - { - const auto options = GetBuiltinOptions(passthrough); - const auto typed_offset = tflite::CreatePackOptions(_flatbuffer, options->values_count(), options->axis()); - offset = typed_offset.Union(); - } - break; - - case tflite::BuiltinOptions::UnpackOptions: - { - const auto options = GetBuiltinOptions(passthrough); - const auto typed_offset = tflite::CreateUnpackOptions(_flatbuffer, options->num(), options->axis()); - offset = typed_offset.Union(); - } - break; - - case tflite::BuiltinOptions::LeakyReluOptions: - { - const auto options = GetBuiltinOptions(passthrough); - offset = tflite::CreateLeakyReluOptions(_flatbuffer, options->alpha()).Union(); - } - break; - - case tflite::BuiltinOptions::ShapeOptions: - { - const auto out_type = GetBuiltinOptions(passthrough)->out_type(); - offset = tflite::CreateShapeOptions(_flatbuffer, out_type).Union(); - } - break; - - case tflite::BuiltinOptions::StridedSliceOptions: - { - const auto options = GetBuiltinOptions(passthrough); - const auto typed_offset = tflite::CreateStridedSliceOptions(_flatbuffer, options->begin_mask(), - options->end_mask(), options->ellipsis_mask(), options->new_axis_mask(), options->shrink_axis_mask()); - offset = typed_offset.Union(); - } - break; - - case tflite::BuiltinOptions::SplitOptions: - { - offset = tflite::CreateSplitOptions(_flatbuffer, int(operation->Outputs().size())).Union(); - } - break; - - case tflite::BuiltinOptions::SplitVOptions: - { - offset = tflite::CreateSplitVOptions(_flatbuffer, int(operation->Outputs().size())).Union(); - } - break; - - case tflite::BuiltinOptions::ReducerOptions: - { - const auto options = GetBuiltinOptions(passthrough); - const auto typed_offset = tflite::CreateReducerOptions(_flatbuffer, options->keep_dims()); - offset = typed_offset.Union(); - } - break; - - case tflite::BuiltinOptions::SVDFOptions: - { - assert(passthrough->builtin_options_as_SVDFOptions()); - tflite::ActivationFunctionType fused_activation_function = passthrough->builtin_options_as_SVDFOptions()->fused_activation_function(); - const auto options = GetBuiltinOptions(passthrough); - const auto typed_offset = tflite::CreateSVDFOptions( - _flatbuffer, options->rank(), fused_activation_function, options->asymmetric_quantize_inputs()); - offset = typed_offset.Union(); - } - break; - - case 
tflite::BuiltinOptions::BatchMatMulOptions: - { - const auto options = GetBuiltinOptions(passthrough); - if ( options ) - { - const auto typed_offset = tflite::CreateBatchMatMulOptions( - _flatbuffer, options->adj_x(), options->adj_y(), options->asymmetric_quantize_inputs()); - offset = typed_offset.Union(); - } - } - break; - case tflite::BuiltinOptions::GatherOptions: - { - const auto options = GetBuiltinOptions(passthrough); - const auto typed_offset = tflite::CreateGatherOptions(_flatbuffer, options->axis(), options->batch_dims()); - offset = typed_offset.Union(); - } - break; - - case tflite::BuiltinOptions::ResizeBilinearOptions: - { - const auto options = GetBuiltinOptions(passthrough); - if ( options ) - { - const auto typed_offset = tflite::CreateResizeBilinearOptions( - _flatbuffer, options->align_corners(), options->half_pixel_centers()); - offset = typed_offset.Union(); - } - else - { - offset = tflite::CreateResizeBilinearOptions(_flatbuffer).Union(); - } - } - break; + return 0; + } - case tflite::BuiltinOptions::CallOnceOptions: - { - const auto options = GetBuiltinOptions(passthrough); - const auto typed_offset = tflite::CreateCallOnceOptions(_flatbuffer, options->init_subgraph_index()); - offset = typed_offset.Union(); - } - break; + const auto *unionTypeTable = tflite::BuiltinOptionsTypeTable(); + assert(unionTypeTable->st == flatbuffers::SequenceType::ST_UNION); + const size_t unionMemberIndex = size_t(unionMemberType) - 1; // 0 is reserved for NONE + assert(unionMemberIndex < unionTypeTable->num_elems); + const auto *unionMemberTypeTable = unionTypeTable->type_refs[unionMemberIndex](); - case tflite::BuiltinOptions::VarHandleOptions: - { - const auto options = GetBuiltinOptions(passthrough); - const auto container = _flatbuffer.CreateString(options->container()); - const auto shared_name = _flatbuffer.CreateString(options->shared_name()); - const auto typed_offset = tflite::CreateVarHandleOptions(_flatbuffer, container, shared_name); - offset = typed_offset.Union(); - } - break; + // The builtin options union member + const auto *unionMember = static_cast(tfliteOperator->builtin_options()); + assert(unionMember); - case tflite::BuiltinOptions::WhileOptions: - { - const auto options = GetBuiltinOptions(passthrough); - if ( options ) - { - const auto typed_offset = tflite::CreateWhileOptions( - _flatbuffer, options->cond_subgraph_index(), options->body_subgraph_index()); - offset = typed_offset.Union(); - } - } - break; - - // Empty option sets can all be written as if they were QuantizeOptions - case tflite::BuiltinOptions::HardSwishOptions: - case tflite::BuiltinOptions::MaximumMinimumOptions: - case tflite::BuiltinOptions::PadOptions: - case tflite::BuiltinOptions::DequantizeOptions: - case tflite::BuiltinOptions::QuantizeOptions: - case tflite::BuiltinOptions::TransposeOptions: - case tflite::BuiltinOptions::GatherNdOptions: - case tflite::BuiltinOptions::ScatterNdOptions: - case tflite::BuiltinOptions::ArgMaxOptions: - case tflite::BuiltinOptions::AssignVariableOptions: - case tflite::BuiltinOptions::ReadVariableOptions: - case tflite::BuiltinOptions::SelectOptions: - case tflite::BuiltinOptions::SelectV2Options: - { - offset = tflite::CreateQuantizeOptions(_flatbuffer).Union(); - } - break; - - case tflite::BuiltinOptions::ConcatEmbeddingsOptions: - case tflite::BuiltinOptions::LSHProjectionOptions: - case tflite::BuiltinOptions::RNNOptions: - case tflite::BuiltinOptions::LocalResponseNormalizationOptions: - case tflite::BuiltinOptions::LSTMOptions: - case 
tflite::BuiltinOptions::CallOptions: - case tflite::BuiltinOptions::SkipGramOptions: - case tflite::BuiltinOptions::SpaceToDepthOptions: - case tflite::BuiltinOptions::EmbeddingLookupSparseOptions: - case tflite::BuiltinOptions::BatchToSpaceNDOptions: - case tflite::BuiltinOptions::SpaceToBatchNDOptions: - case tflite::BuiltinOptions::SequenceRNNOptions: - case tflite::BuiltinOptions::ExpOptions: - case tflite::BuiltinOptions::TopKV2Options: - case tflite::BuiltinOptions::LogSoftmaxOptions: - case tflite::BuiltinOptions::CastOptions: - case tflite::BuiltinOptions::LessOptions: - case tflite::BuiltinOptions::NegOptions: - case tflite::BuiltinOptions::PadV2Options: - case tflite::BuiltinOptions::GreaterOptions: - case tflite::BuiltinOptions::GreaterEqualOptions: - case tflite::BuiltinOptions::LessEqualOptions: - case tflite::BuiltinOptions::SliceOptions: - case tflite::BuiltinOptions::SparseToDenseOptions: - case tflite::BuiltinOptions::TileOptions: - case tflite::BuiltinOptions::ExpandDimsOptions: - case tflite::BuiltinOptions::EqualOptions: - case tflite::BuiltinOptions::NotEqualOptions: - case tflite::BuiltinOptions::PowOptions: - case tflite::BuiltinOptions::ArgMinOptions: - case tflite::BuiltinOptions::FakeQuantOptions: - case tflite::BuiltinOptions::LogicalOrOptions: - case tflite::BuiltinOptions::OneHotOptions: - case tflite::BuiltinOptions::LogicalAndOptions: - case tflite::BuiltinOptions::LogicalNotOptions: - case tflite::BuiltinOptions::FloorDivOptions: - case tflite::BuiltinOptions::SquareOptions: - case tflite::BuiltinOptions::ZerosLikeOptions: - case tflite::BuiltinOptions::FillOptions: - case tflite::BuiltinOptions::BidirectionalSequenceLSTMOptions: - case tflite::BuiltinOptions::BidirectionalSequenceRNNOptions: - case tflite::BuiltinOptions::UnidirectionalSequenceLSTMOptions: - case tflite::BuiltinOptions::FloorModOptions: - case tflite::BuiltinOptions::RangeOptions: - case tflite::BuiltinOptions::ResizeNearestNeighborOptions: - case tflite::BuiltinOptions::SquaredDifferenceOptions: - case tflite::BuiltinOptions::MirrorPadOptions: - case tflite::BuiltinOptions::AbsOptions: - case tflite::BuiltinOptions::UniqueOptions: - case tflite::BuiltinOptions::ReverseV2Options: - case tflite::BuiltinOptions::AddNOptions: - case tflite::BuiltinOptions::CosOptions: - case tflite::BuiltinOptions::WhereOptions: - case tflite::BuiltinOptions::RankOptions: - case tflite::BuiltinOptions::ReverseSequenceOptions: - case tflite::BuiltinOptions::MatrixDiagOptions: - case tflite::BuiltinOptions::MatrixSetDiagOptions: - case tflite::BuiltinOptions::IfOptions: - case tflite::BuiltinOptions::DepthToSpaceOptions: - case tflite::BuiltinOptions::NonMaxSuppressionV4Options: - case tflite::BuiltinOptions::NonMaxSuppressionV5Options: - case tflite::BuiltinOptions::DensifyOptions: - case tflite::BuiltinOptions::SegmentSumOptions: - case tflite::BuiltinOptions::CumsumOptions: - case tflite::BuiltinOptions::BroadcastToOptions: - case tflite::BuiltinOptions::Rfft2dOptions: - case tflite::BuiltinOptions::Conv3DOptions: - case tflite::BuiltinOptions::HashtableOptions: - case tflite::BuiltinOptions::HashtableFindOptions: - case tflite::BuiltinOptions::HashtableImportOptions: - case tflite::BuiltinOptions::HashtableSizeOptions: - LOG_WARN("TfLiteWriter: Built-in options type '{}' is not yet implemented and will be set to default.\n", - tflite::EnumNameBuiltinOptions(type)); - break; - default: - LOG_ERROR("TfLiteWriter: Unrecognised built-in options type '{}'\n", int(type)); - break; - } - return offset; + return 
FlatbufferUtils::CopyTable(_flatbuffer, unionMember, unionMemberTypeTable).Union(); } flatbuffers::Offset TfLiteWriter::SerialiseOptions2(const Operation *operation, OpType opType) @@ -906,20 +493,25 @@ flatbuffers::Offset TfLiteWriter::SerialiseOptions2(const Operation *opera return 0; } - flatbuffers::Offset offset = 0; - const tflite::Operator *const passthrough = static_cast(operation->Passthrough()); - assert(passthrough); - const auto type = passthrough->builtin_options_2_type(); - - switch ( type ) + const auto tfliteOperator = static_cast(operation->Passthrough()); + assert(tfliteOperator); + const tflite::BuiltinOptions2 unionMemberType = tfliteOperator->builtin_options_2_type(); + if ( unionMemberType == tflite::BuiltinOptions2::NONE ) { - case tflite::BuiltinOptions2::NONE: - break; - default: - break; + return 0; } - return offset; + const auto *unionTypeTable = tflite::BuiltinOptions2TypeTable(); + assert(unionTypeTable->st == flatbuffers::SequenceType::ST_UNION); + const size_t unionMemberIndex = size_t(unionMemberType) - 1; // 0 is reserved for NONE + assert(unionMemberIndex < unionTypeTable->num_elems); + const auto *unionMemberTypeTable = unionTypeTable->type_refs[unionMemberIndex](); + + // The builtin options union member + const auto *unionMember = static_cast(tfliteOperator->builtin_options_2()); + assert(unionMember); + + return FlatbufferUtils::CopyTable(_flatbuffer, unionMember, unionMemberTypeTable).Union(); } flatbuffers::Offset TfLiteWriter::SerialiseTensorAddresses(int subgraphs) -- GitLab From e652266fd8101abe9d2ef428f264d126bea19208 Mon Sep 17 00:00:00 2001 From: Johan Gunnarsson Date: Tue, 8 Apr 2025 16:35:56 +0200 Subject: [PATCH 2/2] MLBEDSW-10022: Add a unit test for passthrough This tests passthrough using the following steps: 1. Generate a TFLite network as a flatbuffer. 2. Pass the flatbuffer to tflite_reader to obtain the GraphIR. 3. Pass the GraphIR to tflite_writer to generate a flatbuffer again. 4. Compare it with the flatbuffer from step 1. The contents should be identical. 
Signed-off-by: Johan Gunnarsson Change-Id: I25e6924829845f33aaf24bea665099b819116e67 --- ethosu/regor/test/CMakeLists.txt | 1 + ethosu/regor/test/randomize.hpp | 11 +- ethosu/regor/test/test_passthrough.cpp | 629 +++++++++++++++++++++++++ 3 files changed, 640 insertions(+), 1 deletion(-) create mode 100644 ethosu/regor/test/test_passthrough.cpp diff --git a/ethosu/regor/test/CMakeLists.txt b/ethosu/regor/test/CMakeLists.txt index 0f9391f8..a9cb9fd0 100644 --- a/ethosu/regor/test/CMakeLists.txt +++ b/ethosu/regor/test/CMakeLists.txt @@ -63,6 +63,7 @@ add_catch_test( test_tflite_fb.cpp test_custom_operator_ethosu.cpp test_tflite_supported_operators.cpp + test_passthrough.cpp DEPS test_common ) diff --git a/ethosu/regor/test/randomize.hpp b/ethosu/regor/test/randomize.hpp index b2ccf212..881d9075 100644 --- a/ethosu/regor/test/randomize.hpp +++ b/ethosu/regor/test/randomize.hpp @@ -68,6 +68,14 @@ void randomize(T &value, T min_value = std::numeric_limits::min(), T max_valu } } +// Randomize a real number (with an optional min/max value) +template, int> = 0> +void randomize(T &value, T min_value = std::numeric_limits::min(), T max_value = std::numeric_limits::max()) +{ + std::uniform_real_distribution dist(min_value, max_value); + value = dist(default_rnd_generator); +} + // Usage: // std::vector my_vec; // my_vec.resize(250); @@ -107,7 +115,7 @@ void randomize(std::vector &values) // Create entire randomised vector with bounds template -std::vector random_vector(int length, TYPE min, TYPE max) +std::vector random_vector(int length, TYPE min = std::numeric_limits::min(), TYPE max = std::numeric_limits::max()) { std::uniform_int_distribution distribution(min, max); std::vector temp(length); @@ -143,6 +151,7 @@ T random_of(T first, Ts... rest) T arr[] = {first, rest...}; unsigned index; randomize(index, 0U, unsigned(sizeof...(rest))); + assert(index < 1 + sizeof...(rest)); return arr[index]; } diff --git a/ethosu/regor/test/test_passthrough.cpp b/ethosu/regor/test/test_passthrough.cpp new file mode 100644 index 00000000..0698ead6 --- /dev/null +++ b/ethosu/regor/test/test_passthrough.cpp @@ -0,0 +1,629 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the License); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an AS IS BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include "common/common.hpp" +#include "common/logging.hpp" + +#include "randomize.hpp" +#include "tflite/tflite_mapping.hpp" +#include "tflite/tflite_reader.hpp" +#include "tflite/tflite_schema_generated.hpp" +#include "tflite/tflite_supported_operators.hpp" +#include "tflite/tflite_writer.hpp" +#include "util.hpp" + +#include +#include +#include +#include +#include + +using namespace regor; + +namespace +{ +// Generate random scalar integer in the range [1, 127] +template +T GenerateRandomScalar() +{ + T scalar; + randomize(scalar, T(1), T(127)); + return scalar; +} + +// Generate random scalar boolean +template<> +bool GenerateRandomScalar() +{ + bool scalar; + randomize(scalar, false, true); + return scalar; +} +} // namespace + +// Generate a random enum from a TypeTable +template +static E GenerateRandomEnum(const flatbuffers::TypeTable *typeTable) +{ + assert(typeTable->st == flatbuffers::SequenceType::ST_ENUM); + assert(typeTable->num_elems > 0); + + return E(urandom_range(0, typeTable->num_elems - 1u)); +} + +// Generate a random string +static flatbuffers::Offset<> GenerateRandomString(flatbuffers::FlatBufferBuilder &fbb) +{ + return fbb.CreateString("string-" + std::to_string(urandom())).Union(); +} + +// Generate a random enum vector from a TypeTable +template +static flatbuffers::Offset<> +GenerateRandomEnumVector(flatbuffers::FlatBufferBuilder &fbb, const size_t count, const flatbuffers::TypeTable *typeTable) +{ + assert(typeTable->st == flatbuffers::SequenceType::ST_ENUM); + + std::vector v(count); + for ( size_t i = 0; i < count; i++ ) + { + v[i] = GenerateRandomEnum(typeTable); + } + return fbb.CreateVector(v).Union(); +} + +// Generate a random vector from a TypeTable +template +static flatbuffers::Offset<> GenerateRandomScalarVector(flatbuffers::FlatBufferBuilder &fbb, const size_t count) +{ + std::vector v(count); + for ( size_t i = 0; i < count; i++ ) + { + v[i] = GenerateRandomScalar(); + } + return fbb.CreateVector(v).Union(); +} + +// Generate a random string vector +static flatbuffers::Offset<> GenerateRandomStringVector(flatbuffers::FlatBufferBuilder &fbb, const size_t count) +{ + std::vector v(count); + for ( size_t i = 0; i < count; i++ ) + { + v[i] = "string-" + std::to_string(i); + } + return fbb.CreateVectorOfStrings(v).Union(); +} + +// Generate a randomized union member from a TypeTable +static flatbuffers::Offset<> GenerateRandomUnionMember(flatbuffers::FlatBufferBuilder &fbb, const flatbuffers::TypeTable *typeTable) +{ + // This function generates a flatbuffer Table by iterating over the type table two times. First time to generate + // offsets to all non-scalar, repeating types and strings, then a second time to add all offsets and scalars. The + // function is simplified in the sense that it can't generate all possible tables, but instead has enough + // functionality to generate the tables that are used by TFLite's builtin_options (builtin_options2). 
+ + // Can only generate tables + assert(typeTable->st == flatbuffers::SequenceType::ST_TABLE); + + std::unordered_map> fieldToOffset; + + // Iterate over all types and generate offsets for non-scalar items + for ( flatbuffers::voffset_t i = 0; i < typeTable->num_elems; i++ ) + { + const auto name = typeTable->names[i]; + const auto type = flatbuffers::ElementaryType(typeTable->type_codes[i].base_type); + const auto isRepeating = typeTable->type_codes[i].is_repeating != 0; + const auto sequenceRef = typeTable->type_codes[i].sequence_ref; + const auto field = flatbuffers::FieldIndexToOffset(i); + + // Can not generate unions + assert(type != flatbuffers::ET_UTYPE); + + if ( isRepeating ) + { + const size_t count = urandom_range(1, 5); + + if ( sequenceRef >= 0 ) + { + LOG_TRACE1("Generating vector {} (vector of {})\n", name, flatbuffers::ElementaryTypeNames()[type]); + + // Can only generate ENUMs + const auto *sequenceTypeTable = typeTable->type_refs[sequenceRef](); + assert(sequenceTypeTable); + assert(sequenceTypeTable->st == flatbuffers::SequenceType::ST_ENUM); + + if ( type == flatbuffers::ET_CHAR ) + fieldToOffset[field] = GenerateRandomEnumVector(fbb, count, sequenceTypeTable); + else if ( type == flatbuffers::ET_INT ) + fieldToOffset[field] = GenerateRandomEnumVector(fbb, count, sequenceTypeTable); + else if ( type == flatbuffers::ET_UINT ) + fieldToOffset[field] = GenerateRandomEnumVector(fbb, count, sequenceTypeTable); + else + // Unsupported type (probably ET_SEQUENCE) + assert(false && "Unsupported elementary type"); + } + else + { + LOG_TRACE1("Generating vector {} (vector of {})\n", name, flatbuffers::ElementaryTypeNames()[type]); + + // Vector of scalars/strings + if ( type == flatbuffers::ET_BOOL ) fieldToOffset[field] = GenerateRandomScalarVector(fbb, count); + else if ( type == flatbuffers::ET_CHAR ) + fieldToOffset[field] = GenerateRandomScalarVector(fbb, count); + else if ( type == flatbuffers::ET_UCHAR ) + fieldToOffset[field] = GenerateRandomScalarVector(fbb, count); + else if ( type == flatbuffers::ET_SHORT ) + fieldToOffset[field] = GenerateRandomScalarVector(fbb, count); + else if ( type == flatbuffers::ET_USHORT ) + fieldToOffset[field] = GenerateRandomScalarVector(fbb, count); + else if ( type == flatbuffers::ET_INT ) + fieldToOffset[field] = GenerateRandomScalarVector(fbb, count); + else if ( type == flatbuffers::ET_UINT ) + fieldToOffset[field] = GenerateRandomScalarVector(fbb, count); + else if ( type == flatbuffers::ET_LONG ) + fieldToOffset[field] = GenerateRandomScalarVector(fbb, count); + else if ( type == flatbuffers::ET_ULONG ) + fieldToOffset[field] = GenerateRandomScalarVector(fbb, count); + else if ( type == flatbuffers::ET_FLOAT ) + fieldToOffset[field] = GenerateRandomScalarVector(fbb, count); + else if ( type == flatbuffers::ET_DOUBLE ) + fieldToOffset[field] = GenerateRandomScalarVector(fbb, count); + else if ( type == flatbuffers::ET_STRING ) + fieldToOffset[field] = GenerateRandomStringVector(fbb, count); + else assert(false && "Unsupported elementary type"); + } + } + else + { + LOG_TRACE1("Generating string {} ({})\n", name, flatbuffers::ElementaryTypeNames()[type]); + + if ( type == flatbuffers::ET_STRING ) fieldToOffset[field] = GenerateRandomString(fbb); + } + } + + const auto tableOffset = fbb.StartTable(); + + // Iterate over all types and add offsets and scalar types + for ( flatbuffers::voffset_t i = 0; i < typeTable->num_elems; i++ ) + { + const auto name = typeTable->names[i]; + const auto type = 
flatbuffers::ElementaryType(typeTable->type_codes[i].base_type); + const auto isRepeating = typeTable->type_codes[i].is_repeating != 0; + const auto sequenceRef = typeTable->type_codes[i].sequence_ref; + const auto field = flatbuffers::FieldIndexToOffset(i); + + if ( fieldToOffset.count(field) == 0 ) + { + // At this point it's too late for repeating types + assert(!isRepeating); + + if ( sequenceRef >= 0 ) + { + LOG_TRACE1("Generating and adding sequence {} ({})\n", name, flatbuffers::ElementaryTypeNames()[type]); + + const auto *sequenceTypeTable = typeTable->type_refs[sequenceRef](); + assert(sequenceTypeTable); + assert(sequenceTypeTable->st == flatbuffers::SequenceType::ST_ENUM); + + if ( type == flatbuffers::ET_CHAR ) + fbb.AddElement(field, GenerateRandomEnum(sequenceTypeTable)); + else if ( type == flatbuffers::ET_INT ) + fbb.AddElement(field, GenerateRandomEnum(sequenceTypeTable)); + else if ( type == flatbuffers::ET_UINT ) + fbb.AddElement(field, GenerateRandomEnum(sequenceTypeTable)); + else + // Unsupported type (probably ET_SEQUENCE) + assert(false && "Unsupported elementary type"); + } + else + { + LOG_TRACE1("Generating and adding scalar {} ({})\n", name, flatbuffers::ElementaryTypeNames()[type]); + + // Scalar + if ( type == flatbuffers::ET_BOOL ) fbb.AddElement(field, GenerateRandomScalar()); + else if ( type == flatbuffers::ET_CHAR ) fbb.AddElement(field, GenerateRandomScalar()); + else if ( type == flatbuffers::ET_UCHAR ) fbb.AddElement(field, GenerateRandomScalar()); + else if ( type == flatbuffers::ET_SHORT ) fbb.AddElement(field, GenerateRandomScalar()); + else if ( type == flatbuffers::ET_USHORT ) fbb.AddElement(field, GenerateRandomScalar()); + else if ( type == flatbuffers::ET_INT ) fbb.AddElement(field, GenerateRandomScalar()); + else if ( type == flatbuffers::ET_UINT ) fbb.AddElement(field, GenerateRandomScalar()); + else if ( type == flatbuffers::ET_LONG ) fbb.AddElement(field, GenerateRandomScalar()); + else if ( type == flatbuffers::ET_ULONG ) fbb.AddElement(field, GenerateRandomScalar()); + else if ( type == flatbuffers::ET_FLOAT ) fbb.AddElement(field, GenerateRandomScalar()); + else if ( type == flatbuffers::ET_DOUBLE ) fbb.AddElement(field, GenerateRandomScalar()); + else assert(false && "Unsupported elementary type"); + } + } + else + { + LOG_TRACE1("Adding offset {} ({})\n", name, flatbuffers::ElementaryTypeNames()[type]); + + fbb.AddOffset(field, fieldToOffset[field]); + } + } + + return fbb.EndTable(tableOffset); +} + +// Compare two lists of strings +static void CompareItems(const std::vector &actual, const std::vector &expected) +{ + static const std::array ignoredPrefixes = { + "/model/description", + }; + + // Log all items + for ( auto &item : actual ) + LOG_TRACE2("Actual: {}\n", item); + for ( auto &item : expected ) + LOG_TRACE2("Expected: {}\n", item); + + // Check each item + auto actualIter = actual.begin(); + auto expectedIter = expected.begin(); + while ( actualIter != actual.end() && expectedIter != expected.end() ) + { + bool ignoreActual = false; + bool ignoreExpected = false; + for ( auto &prefix : ignoredPrefixes ) + { + ignoreActual = ignoreActual || (actualIter->rfind(prefix, 0) != std::string::npos); + ignoreExpected = ignoreExpected || (expectedIter->rfind(prefix, 0) != std::string::npos); + } + if ( ignoreActual ) + { + LOG_TRACE1("Ignoring actual item: {}\n", *actualIter); + actualIter++; + } + if ( ignoreExpected ) + { + LOG_TRACE1("Ignoring expected item: {}\n", *expectedIter); + expectedIter++; + } + if ( !ignoreActual && 
!ignoreExpected ) + { + auto &a = *actualIter; + auto &b = *expectedIter; + REQUIRE(a == b); + actualIter++; + expectedIter++; + } + } + + // Make sure we have consumed all items + const bool consumedAllActual = (actualIter == actual.end()); + const bool consumedAllExpected = (expectedIter == expected.end()); + + // Log all items left + for ( ; actualIter != actual.end(); actualIter++ ) + LOG_TRACE2("Unconsumed actual: {}\n", *actualIter); + for ( ; expectedIter != expected.end(); expectedIter++ ) + LOG_TRACE2("Unconsumed expected: {}\n", *expectedIter); + + REQUIRE(consumedAllActual); + REQUIRE(consumedAllExpected); +} + +// Visitor that converts a flatbuffer to a vector of strings, one string per elementary item +class ToItemListVisitor : public flatbuffers::IterationVisitor +{ +public: + void StartSequence() { sequences.emplace_back(); } + void EndSequence() { sequences.pop_back(); } + void StartVector() { vectors.push_back(currentField); } + void EndVector() { vectors.pop_back(); } + + void Field(size_t, size_t, flatbuffers::ElementaryType type, bool, const flatbuffers::TypeTable *, const char *name, const uint8_t *) + { + assert(name); + + currentField = name; + currentFieldType = flatbuffers::ElementaryTypeNames()[type]; + sequences.back() = currentField; + } + + void Element(size_t i, flatbuffers::ElementaryType type, const flatbuffers::TypeTable *, const uint8_t *) + { + currentFieldType = flatbuffers::ElementaryTypeNames()[type]; + sequences.back() = vectors.back() + "[" + std::to_string(i) + "]"; + } + + template + void HandleItem(T value, const char *name) + { + // Build the item path + auto item = std::accumulate(sequences.begin(), sequences.end(), std::string(), + [](const std::string &a, const std::string &b) { return a + "/" + b; }); + + // Append item value (and type if available) + if ( name ) item += fmt::format(" = {}", name); + else item += fmt::format(" = {} ({})", value, currentFieldType); + + items.push_back(std::move(item)); + } + + void UType(uint8_t x, const char *name) { HandleItem(x, name); } + void Bool(bool x) { HandleItem(x, nullptr); } + void Char(int8_t x, const char *name) { HandleItem(x, name); } + void UChar(uint8_t x, const char *name) { HandleItem(x, name); } + void Short(int16_t x, const char *name) { HandleItem(x, name); } + void UShort(uint16_t x, const char *name) { HandleItem(x, name); } + void Int(int32_t x, const char *name) { HandleItem(x, name); } + void UInt(uint32_t x, const char *name) { HandleItem(x, name); } + void Long(int64_t x) { HandleItem(x, nullptr); } + void ULong(uint64_t x) { HandleItem(x, nullptr); } + void Float(float x) { HandleItem(x, nullptr); } + void Double(double x) { HandleItem(x, nullptr); } + void String(const flatbuffers::String *str) { HandleItem(str->string_view(), nullptr); } + void Unknown(const uint8_t *) { HandleItem("?", "UNKNOWN"); } + + std::vector items; + +private: + std::deque sequences{"model"}; + std::deque vectors; + std::string currentField; + std::string currentFieldType; +}; + +// Mark one op as passthrough and remove any associated activation function +static void MarkAsPassthrough(Operation *op) +{ + if ( TfLiteMapping::CanFuseActivationFunction(op) ) + { + const auto ofm = op->OFM(); + assert(ofm); + assert(ofm->IsSinglePath()); + const auto activation = ofm->Readers().front(); + const auto actOfmConn = activation->Output(TensorUsage::OFM); + assert(actOfmConn); + + // Bypass and remove activation op + op->CopyOutput(TensorUsage::OFM, *actOfmConn); + activation->SetPassthroughOp(); + 
activation->Disconnect(); + } + + op->SetPassthroughOp(); +} + +TEST_CASE("passthrough") +{ + // This tests the passthrough functionality of Regor. Passthrough refers to outputting an unsupported operator + // unchanged so that it can be executed by the TFLite or TFLite Micro framework on the CPU. + // + // This test ensures that functionality as follows: + // + // 1. Generate a TFLite network as a flatbuffer. + // 2. Pass the flatbuffer to tflite_reader to obtain the GraphIR. + // 3. Pass the GraphIR to tflite_writer to generate a flatbuffer again. + // 4. Compare it with the flatbuffer from step 1. The contents should be identical. + // + // This is done for all operators in TFLite (together with their BuiltinOptions or BuiltinOptions2). + + DisableLogging disableLogging; + + // Use from_range to generate values from the array + const tflite::BuiltinOperator op = GENERATE( + from_range(std::begin(tflite::EnumValuesBuiltinOperator()), std::end(tflite::EnumValuesBuiltinOperator()))); + LOG_TRACE1("Testing operator {}\n", tflite::EnumNameBuiltinOperator(op)); + + flatbuffers::FlatBufferBuilder fbb; + + // Per model + std::vector> operatorCodes; + std::vector> subgraphs; + std::vector> buffers; + + // Per subgraph + std::vector> operations; + std::vector> tensors; + std::vector> metadata; + + { + // Generate 1 operator code + const int8_t deprecatedBuiltinCode = std::min(int8_t(op), 127); + const char *customCode = nullptr; + const int32_t version = urandom_range(1, 5); + const tflite::BuiltinOperator builtinCode = op; + operatorCodes.push_back(tflite::CreateOperatorCodeDirect(fbb, deprecatedBuiltinCode, customCode, version, builtinCode)); + } + + // Generate empty first buffer + buffers.push_back(tflite::CreateBufferDirect(fbb)); // Buffer 0 + + { + // Generate tensor + const std::vector shape = {1, 9, 9, 3}; + const tflite::TensorType type = tflite::TensorType::INT8; + const int bufferIndex = 0; + const std::string name = "ifm0"; + + // Create QuantizationParameters + const std::vector min = random_vector(3, 0, 127); + const std::vector max = random_vector(3, 128, 255); + const std::vector scale = {0.0042}; + const std::vector zeroPoint = {3, 7, 11}; + const tflite::QuantizationDetails detailsType = tflite::QuantizationDetails::CustomQuantization; + const std::vector custom = random_vector(4, 0, 255); + const auto details = tflite::CreateCustomQuantizationDirect(fbb, &custom).Union(); + const int32_t quantizedDimension = 3; + const auto quantization = tflite::CreateQuantizationParametersDirect( + fbb, &min, &max, &scale, &zeroPoint, detailsType, details, quantizedDimension); + + const bool isVariable = random_of(true, false); + + // Create SparsityParameters + const std::vector traversalOrder = random_vector(4); + const std::vector blockMap = random_vector(4); + const tflite::DimensionType format = tflite::DimensionType::DENSE; + const int32_t denseSize = urandom(); + const tflite::SparseIndexVector arraySegmentsType = tflite::SparseIndexVector::Uint16Vector; + std::vector arraySegmentsValues = random_vector(4); + const auto arraySegments = tflite::CreateUint16VectorDirect(fbb, &arraySegmentsValues).Union(); + const tflite::SparseIndexVector arrayIndicesType = tflite::SparseIndexVector::Uint16Vector; + std::vector arrayIndicesValues = random_vector(4); + const auto arrayIndices = tflite::CreateUint16VectorDirect(fbb, &arrayIndicesValues).Union(); + const std::vector> dimMetadata = { + tflite::CreateDimensionMetadata(fbb, format, denseSize, arraySegmentsType, arraySegments, 
arrayIndicesType, arrayIndices), + }; + const auto sparsity = tflite::CreateSparsityParametersDirect(fbb, &traversalOrder, &blockMap, &dimMetadata); + + const std::vector shapeSignature = {1, 9, 9, 3}; + const bool hasRank = random_of(true, false); + + // Create VariantSubType + const std::vector shape1 = random_vector(4, 1, 32); + const tflite::TensorType type1 = random_of(tflite::TensorType::INT8, tflite::TensorType::INT16, tflite::TensorType::INT32); + const bool hasRank1 = random_of(true, false); + const std::vector> variant_tensors = { + tflite::CreateVariantSubTypeDirect(fbb, &shape1, type1, hasRank1), + }; + + tensors.push_back(tflite::CreateTensorDirect(fbb, &shape, type, bufferIndex, name.c_str(), quantization, + isVariable, sparsity, &shapeSignature, hasRank, &variant_tensors)); + } + + for ( auto &index : {1, 2, 3} ) + { + // Generate buffer with data + std::vector data = random_vector(9, 0, 255); + buffers.push_back(tflite::CreateBufferDirect(fbb, &data)); + + // Generate simple constant tensor + const std::vector shape = {1, 1, 3, 3}; + const tflite::TensorType type = tflite::TensorType::INT16; + const int bufferIndex = buffers.size() - 1; + const std::string name = "const-" + std::to_string(index); + tensors.push_back(tflite::CreateTensorDirect(fbb, &shape, type, bufferIndex, name.c_str())); + } + + { + // Generate simple output tensor + const std::vector shape = {1, 11, 11, 3}; + const tflite::TensorType type = tflite::TensorType::FLOAT32; + const int bufferIndex = 0; + const std::string name = "ofm"; + tensors.push_back(tflite::CreateTensorDirect(fbb, &shape, type, bufferIndex, name.c_str())); + } + + /* + // TODO: MLBEDSW-9946: Metadata passthrough is currently not supported + { + std::vector data2 = random_vector(5, 0, 255); + serialised_buffers.push_back(tflite::CreateBufferDirect(fbb, &data2)); + serialised_metadata.push_back(tflite::CreateMetadataDirect(fbb, "metadata1", serialised_buffers.size() - 1)); + } + */ + + { + // Generate 1 operator + const uint32_t opcodeIndex = 0; + const std::vector inputs = {0, 1, 2, 3}; + const std::vector outputs = {4}; + const std::vector customOptions = random_vector(5); + const std::vector mutatingVariableInputs = random_vector(4); + const std::vector intermediates = random_vector(4); + + // Generate builtin_options or builtin_options2 + flatbuffers::Offset<> builtinOptions = 0; + tflite::BuiltinOptions builtinOptionsType = TfLiteMapping::BuiltinOperatorToBuiltinOptions(op); + flatbuffers::Offset<> builtinOptions2 = 0; + tflite::BuiltinOptions2 builtinOptions2Type = TfLiteMapping::BuiltinOperatorToBuiltinOptions2(op); + if ( builtinOptionsType != tflite::BuiltinOptions::NONE ) + { + LOG_TRACE1("Generating union {}\n", tflite::EnumNameBuiltinOptions(builtinOptionsType)); + builtinOptions = GenerateRandomUnionMember( + fbb, tflite::BuiltinOptionsTypeTable()->type_refs[size_t(builtinOptionsType) - 1]()); + } + else if ( builtinOptions2Type != tflite::BuiltinOptions2::NONE ) + { + LOG_TRACE1("Generating union {}\n", tflite::EnumNameBuiltinOptions2(builtinOptions2Type)); + builtinOptions2 = GenerateRandomUnionMember( + fbb, tflite::BuiltinOptions2TypeTable()->type_refs[size_t(builtinOptions2Type) - 1]()); + } + + operations.push_back(tflite::CreateOperatorDirect(fbb, opcodeIndex, &inputs, &outputs, builtinOptionsType, + builtinOptions, &customOptions, tflite::CustomOptionsFormat::FLEXBUFFERS, &mutatingVariableInputs, + &intermediates, 0, 0, builtinOptions2Type, builtinOptions2)); + } + + { + // Generate 1 subgraph + const std::vector 
inputs = {0 /* ifm0 */}; + const std::vector outputs = {4 /* ofm */}; + const char *name = "subgraph1"; + subgraphs.push_back(tflite::CreateSubGraphDirect(fbb, &tensors, &inputs, &outputs, &operations, name)); + } + + // TODO: add metadata_buffer, metadata and signature_defs + // Generate 1 model + const auto model1 = tflite::CreateModelDirect( + fbb, 3 /* Version */, &operatorCodes, &subgraphs, "description1", &buffers, nullptr, &metadata); + + // Create TFLite flatbuffer + tflite::FinishModelBuffer(fbb, model1); + const auto bufExpected = fbb.Release(); + LOG_TRACE1("Created network ({}, size {})\n", fmt::ptr(bufExpected.data()), bufExpected.size()); + + // Read TFLite network as GraphIR + std::vector> graphs1; + TfLiteReader::LoadGraphs(bufExpected.data(), bufExpected.size(), graphs1, nullptr, true); + LOG_TRACE1("Read network ({} subgraphs)\n", graphs1.size()); + + // Check GraphIR + REQUIRE(graphs1.size() == 1); + REQUIRE(graphs1[0]->Inputs().size() == 1); + REQUIRE(graphs1[0]->Outputs().size() == 1); + std::vector ops1; + graphs1[0]->GetAllOperations(ops1); + REQUIRE((ops1.size() == 1 || ops1.size() == 2)); + MarkAsPassthrough(ops1[0]); + ops1.clear(); + graphs1[0]->GetAllOperations(ops1); + REQUIRE(ops1.size() == 1); + REQUIRE(ops1[0]->Type() == OpType::Passthrough); + + // Write GraphIR as TFLite flatbuffer + std::vector> maps{{}}; + int64_t offset = 0; + size_t size = 0; + TfLiteWriter writer(1 << 31, true /* skip OfflineMemoryAllocation */); + graphs1[0]->SetScheduledOrder(std::move(ops1)); + const auto bufActual = writer.Serialise(graphs1, maps, offset, size); + LOG_TRACE1("Wrote network ({}, offset {}, size {})\n", fmt::ptr(bufActual.get()), offset, size); + + // Parse actual TFLite flatbuffer to a list of items + flatbuffers::Verifier::Options options; + flatbuffers::Verifier verifier1(bufActual.get() + offset, size, options); + REQUIRE(tflite::VerifyModelBuffer(verifier1)); + auto modelActual = tflite::GetModel(bufActual.get() + offset); + ToItemListVisitor toItemListVisitor1; + IterateFlatBuffer(bufActual.get() + offset, modelActual->MiniReflectTypeTable(), &toItemListVisitor1); + + // Parse expected TFLite flatbuffer to a list of items + flatbuffers::Verifier verifier2(bufExpected.data(), bufExpected.size(), options); + REQUIRE(tflite::VerifyModelBuffer(verifier2)); + auto modelExpected = tflite::GetModel(bufExpected.data()); + ToItemListVisitor toItemListVisitor2; + IterateFlatBuffer(bufExpected.data(), modelExpected->MiniReflectTypeTable(), &toItemListVisitor2); + + // Compare actual and expected flatbuffer + CompareItems(toItemListVisitor1.items, toItemListVisitor2.items); +} -- GitLab
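
The patches above only exercise FlatbufferUtils::CopyTable from inside regor, so a minimal standalone usage sketch follows for reviewers. It is not part of either patch: the typed CopyTable overload and the generated TFLite types are taken from the diffs, while the include paths and the QuantizationParameters round-trip scenario are illustrative assumptions.

// Sketch: copy a tflite::QuantizationParameters table from one FlatBufferBuilder to
// another using the mini-reflection based FlatbufferUtils::CopyTable from patch 1.
// Include paths are assumptions; adjust them to the actual regor source layout.
#include "tflite/flatbuffer_utils.hpp"
#include "tflite/tflite_schema_generated.hpp"

#include <flatbuffers/flatbuffers.h>

#include <vector>

static flatbuffers::Offset<tflite::QuantizationParameters> RoundTripQuantization(flatbuffers::FlatBufferBuilder &dst)
{
    // Build a small source table in a separate builder.
    flatbuffers::FlatBufferBuilder src;
    std::vector<float> scale{0.25f};
    std::vector<int64_t> zeroPoint{0};
    src.Finish(tflite::CreateQuantizationParametersDirect(src, nullptr, nullptr, &scale, &zeroPoint));
    const auto *table = flatbuffers::GetRoot<tflite::QuantizationParameters>(src.GetBufferPointer());

    // Same call shape as SerialiseTensor() in patch 1: the typed overload resolves the
    // TypeTable via T::MiniReflectTypeTable() and returns a typed Offset.
    return FlatbufferUtils::CopyTable(dst, table);
}

SerialiseOptions() and SerialiseOptions2() rely on the same mechanism for the builtin-options unions: they look up the member's TypeTable through tflite::BuiltinOptionsTypeTable()->type_refs[...] and call the untyped CopyTable overload, so no per-operator switch is needed any more.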