From 0a447aff5a033d756f6f58528c12411c3f4d481f Mon Sep 17 00:00:00 2001 From: Johan Gunnarsson Date: Thu, 5 Dec 2024 16:42:03 +0100 Subject: [PATCH] MLBEDSW-9689: Lower TFLite StridedSlice into TOSA Slice * Add support for >3D input. * Add support for int32 input. * Add support for the ellipsis attribute. * Stride parameter tensor is still not supported and marked as passthrough. Signed-off-by: Johan Gunnarsson Change-Id: Iad3ff021ac921f68526a086a42313d9136b9893b --- .../regor/compiler/tflite_graph_optimiser.cpp | 320 +++++++----------- .../regor/compiler/tflite_graph_optimiser.hpp | 19 +- 2 files changed, 124 insertions(+), 215 deletions(-) diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp index 9f646cda..5ca62807 100644 --- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp +++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp @@ -250,50 +250,6 @@ int TFLiteGraphOptimiser::GetAxis(const Operation *const operation) } -// Calculate the read shape and offset values for StridedSlice. -void TFLiteGraphOptimiser::SetStridedSliceOffsetValues( - Operation *const operation, const TensorConnection *const ifmConn, Shape &readShape, Shape &readOffset) -{ - auto *beginConn = operation->Input(TensorUsage::Params0); - auto *endConn = operation->Input(TensorUsage::Params1); - - const tflite::Operator *passthrough = static_cast(operation->Passthrough()); - assert(passthrough); - auto *opt = passthrough->builtin_options_as_StridedSliceOptions(); - assert(opt); - - // strides tensor not used. - auto beginMask = opt->begin_mask(); - auto endMask = opt->end_mask(); - - readShape = ifmConn->shape; - - for ( auto idx = 0; idx < ifmConn->shape.Size(); idx++ ) - { - // If the i:th bit in the mask is set then the value on offset_tens[i] should be ignored - if ( (beginMask & (1 << idx)) == 0 ) - { - readOffset[idx] = beginConn->tensor->View().Values()[idx]; - if ( readOffset[idx] < 0 ) - { - // Convert offset to positive value - readOffset[idx] += ifmConn->shape[idx]; - } - } - if ( (endMask & (1 << idx)) == 0 ) - { - readShape[idx] = endConn->tensor->View().Values()[idx]; - if ( readShape[idx] < 0 ) - { - // Convert offset to positive value - readShape[idx] += ifmConn->shape[idx]; - } - } - } - readOffset = Shape::PadAxes(readOffset, 4, 0); -} - - // Creates MemoryCopy operation for the given ifm/ofm and write offset. std::shared_ptr TFLiteGraphOptimiser::MakeMemoryCopyForConcat( const TensorConnection *const ofmConn, const TensorConnection *const ifmConn, const Shape &writeOffset) @@ -311,73 +267,6 @@ std::shared_ptr TFLiteGraphOptimiser::MakeMemoryCopyForConcat( } -// Creates a MemoryCopy operation for the given ifm/ofm and readOffset. -std::shared_ptr TFLiteGraphOptimiser::MakeMemoryCopyForSplitOps(const TensorConnection *const ofmConn, - const TensorConnection *const ifmConn, const Shape &readShape, const Shape &readOffset) -{ - auto op = std::make_shared(OpType::MemoryCopy); - op->SetRounding(RoundMode::NATURAL); - op->ConnectInput(TensorUsage::IFM0, ifmConn->tensor).Set(ifmConn->shape).Set(ifmConn->quantization).Set({readOffset, readShape}); - op->CopyOutput(TensorUsage::OFM, *ofmConn); - - return op; -} - - -// Creates the desired Output shape of StridedSlice. -// -// returns the Desired shape. -Shape TFLiteGraphOptimiser::MakeStridedSliceDesiredShape(Operation *const operation, const Shape &baseShape) -{ - const tflite::Operator *passthrough = static_cast(operation->Passthrough()); - assert(passthrough); - auto *opt = passthrough->builtin_options_as_StridedSliceOptions(); - assert(opt); - unsigned newMask = unsigned(opt->new_axis_mask()); - unsigned shrinkMask = unsigned(opt->shrink_axis_mask()); - - if ( newMask == 0 && shrinkMask == 0 ) - { - return baseShape; - } - assert((newMask == 0) || (shrinkMask == 0)); - - Shape tmp = baseShape; - while ( shrinkMask ) - { - auto prevMask = shrinkMask; - shrinkMask &= shrinkMask - 1; - auto axis = 0; - auto diff = prevMask - shrinkMask; - diff >>= 1; - while ( diff ) - { - diff >>= 1; - ++axis; - } - tmp = tmp.Insert(axis, 1); - } - - while ( newMask ) - { - auto prevMask = newMask; - newMask &= newMask - 1; - auto axis = 0; - auto diff = prevMask - newMask; - diff >>= 1; - while ( diff ) - { - diff >>= 1; - ++axis; - } - tmp = tmp.Erase(axis); - newMask >>= 1; - } - - return Shape::PadAxes(tmp, 4, 1); -} - - Operation *TFLiteGraphOptimiser::MakeDepthwiseMeanOp(const TensorConnection *ifmConn, const Shape &ifmShape4D, const Shape &readShape, const Shape &readOffset, const Shape &ofmShape4D, int w, int h, const std::string &name, std::shared_ptr &weightTensor, std::shared_ptr biasTensor, const Quantization &ifmQuant, const Quantization &weightQuant, const Quantization &ofmQuant) @@ -752,6 +641,128 @@ Operation *TFLiteGraphOptimiser::RewriteSlice(Graph *const graph, Operation *con return returnOp; } + +// Convert TFLite StridedSlice into TOSA Slice +Operation *TFLiteGraphOptimiser::RewriteStridedSlice(Graph *const graph, Operation *const operation) +{ + UNUSED(graph); + auto *returnOp = operation; + const auto opType = operation->Type(); + if ( opType == OpType::StridedSlice ) + { + const auto *ifmConn = operation->Input(TensorUsage::IFM); + const auto *ofmConn = operation->Output(TensorUsage::OFM); + const auto *beginParmConn = operation->Input(TensorUsage::Params0); + const auto *endParamConn = operation->Input(TensorUsage::Params1); + const auto *stridesParamConn = operation->Input(TensorUsage::Params2); + + // Read StridedSlice attributes + int32_t begin_mask = 0; + int32_t ellipsis_mask = 0; + int32_t end_mask = 0; + int32_t new_axis_mask = 0; + int32_t shrink_axis_mask = 0; + const tflite::Operator *const passthrough = static_cast(operation->Passthrough()); + if ( passthrough ) + { + const auto options = passthrough->builtin_options_as_StridedSliceOptions(); + if ( options ) + { + begin_mask = options->begin_mask(); + ellipsis_mask = options->ellipsis_mask(); + end_mask = options->end_mask(); + new_axis_mask = options->new_axis_mask(); + shrink_axis_mask = options->shrink_axis_mask(); + } + } + + const Shape beginAttr = TensorToShape(beginParmConn->tensor.get(), beginParmConn->shape.Elements()); + const Shape endAttr = TensorToShape(endParamConn->tensor.get(), endParamConn->shape.Elements()); + const Shape stridesAttr = TensorToShape(stridesParamConn->tensor.get(), stridesParamConn->shape.Elements()); + const int specShapeSize = std::min({beginAttr.Size(), endAttr.Size(), stridesAttr.Size()}); + + // Start off with the full IFM + const int ifmShapeSize = ifmConn->shape.Size(); + Shape sliceOffset(nullptr, ifmShapeSize, 0); + Shape sliceShape(ifmConn->shape); + Shape sliceStride(nullptr, ifmShapeSize, 1); + + // Process each spec + for ( int specIndex = 0, ifmIndex = 0; specIndex < specShapeSize; specIndex++ ) + { + const bool isBegin = (begin_mask & (1 << specIndex)) != 0; + const bool isEllipsis = (ellipsis_mask & (1 << specIndex)) != 0; + const bool isEnd = (end_mask & (1 << specIndex)) != 0; + const bool isNewAxis = (new_axis_mask & (1 << specIndex)) != 0; + const bool isShrink = (shrink_axis_mask & (1 << specIndex)) != 0; + + if ( isEllipsis ) + { + // Skip to the end + ifmIndex = ifmShapeSize - (specShapeSize - specIndex - 1); + assert(ifmIndex >= 0); + assert(ifmIndex <= ifmShapeSize); + } + else + { + if ( !isBegin || isShrink ) + { + // Handle the begin value + int begin = beginAttr[specIndex]; + if ( begin < 0 ) begin = ifmConn->shape[ifmIndex] + begin; + begin = std::clamp(begin, 0, ifmConn->shape[ifmIndex] - 1); + sliceOffset[ifmIndex] = begin; + sliceShape[ifmIndex] = isShrink ? 1 : ifmConn->shape[ifmIndex] - begin; + } + + if ( !isEnd && !isShrink ) + { + // Handle the end value + int end = endAttr[specIndex]; + if ( end < 0 ) end = ifmConn->shape[ifmIndex] + end; + end = std::clamp(end, 1, ifmConn->shape[ifmIndex]); + assert(end > sliceOffset[ifmIndex]); + sliceShape[ifmIndex] = end - sliceOffset[ifmIndex]; + } + + // Handle the stride value + sliceStride[ifmIndex] = stridesAttr[specIndex]; + + // Go to next dimension + ifmIndex++; + } + } + + // TODO MLBEDSW-10165: Handle stride != 1 + if ( sliceStride != sliceStride.WithOnes() ) + { + returnOp->SetPassthroughOp(); + return returnOp; + } + + // Adjust resulting shape for stride + sliceShape = Shape::DivRoundUp(sliceShape, sliceStride); + + // Create a new SLICE op + auto sliceOp = std::make_shared(OpType::Slice); + sliceOp->CopyInput(TensorUsage::IFM, *ifmConn); + sliceOp->CopyOutput(TensorUsage::OFM, *ofmConn); + sliceOp->Output(TensorUsage::OFM)->Set(sliceShape); + auto *attr = sliceOp->Attribute(); + assert(sliceOffset + sliceShape <= ifmConn->shape); + assert(sliceOffset >= ifmConn->shape.WithZeros()); + attr->size = sliceShape; + attr->begin = sliceOffset; + RecordOptimisation(operation, sliceOp.get()); + returnOp = sliceOp.get(); + + // Remove original op + operation->Disconnect(); + } + return returnOp; +} + + // Convert TFLite Unpack/Split/SplitV into one or more TOSA Slice Operation *TFLiteGraphOptimiser::RewriteUnpack(Graph *const graph, Operation *const operation) { @@ -799,93 +810,6 @@ Operation *TFLiteGraphOptimiser::RewriteUnpack(Graph *const graph, Operation *co return returnOp; } -Operation *TFLiteGraphOptimiser::RewriteSplit(Graph *const graph, Operation *const operation) -{ - UNUSED(graph); - auto *returnOp = operation; - auto opType = operation->Type(); - - if ( opType == OpType::StridedSlice ) - { - auto *ifmConn = operation->Input(TensorUsage::IFM0); - assert(ifmConn); - auto *ofmConn = operation->Output(TensorUsage::OFM); - assert(ofmConn); - auto axis = GetAxis(operation); - auto axis4D = 0; - - if ( opType == OpType::StridedSlice ) - { - const tflite::Operator *passthrough = static_cast(operation->Passthrough()); - assert(passthrough); - auto *opt = passthrough->builtin_options_as_StridedSliceOptions(); - assert(opt); - // StridedSlice ellipsis_mask not supported. - // StridedSlice new_axis_mask and shrink_axis_mask cannot both be set. - const auto ellipsis_mask = opt->ellipsis_mask(); - const auto new_axis_mask = opt->new_axis_mask(); - const auto shrink_axis_mask = opt->shrink_axis_mask(); - if ( ellipsis_mask != 0 || (new_axis_mask != 0 && shrink_axis_mask != 0) ) - { - returnOp->SetPassthroughOp(); - return returnOp; - } - } - - // Only rewrite for int8, uint8 and int16 supported. - auto ifmType = ifmConn->tensor->Type(); - if ( ifmType != DataType::Int8 && ifmType != DataType::UInt8 && ifmType != DataType::Int16 ) - { - returnOp->SetPassthroughOp(); - return returnOp; - } - - // Only rewrite for int8, uint8 and int16 supported. - auto ofmType = ofmConn->tensor->Type(); - if ( ofmType != DataType::Int8 && ofmType != DataType::UInt8 && ofmType != DataType::Int16 ) - { - returnOp->SetPassthroughOp(); - return returnOp; - } - - auto idx = 0; - auto offset = 0; - auto usage = MakeTensorUsage(TensorUsage::OFM, 0); - ofmConn = operation->Output(usage); - // Set shape on all OFMs - while ( ofmConn != nullptr ) - { - // Remove writers from OFM - auto *ofm = ofmConn->tensor.get(); - ofm->RemoveWriters(); - - Shape readOffset(0, 0, 0, 0); - Shape readShape(1, 1, 1, 1); - - if ( opType == OpType::StridedSlice ) - { - // TODO: MLBEDSW-9071: Change StridedSlice shape to 4D - ofmConn->shape = MakeStridedSliceDesiredShape(operation, ofmConn->shape); - readShape = ifmConn->shape.WithOnes(); - readOffset = ifmConn->shape.WithZeros(); - SetStridedSliceOffsetValues(operation, ifmConn, readShape, readOffset); - } - - auto op = MakeMemoryCopyForSplitOps(ofmConn, ifmConn, readShape, readOffset); - offset += ofmConn->shape[axis4D]; - - usage = MakeTensorUsage(TensorUsage::OFM, ++idx); - ofmConn = operation->Output(usage); - RecordOptimisation(operation, op.get()); - } - // Replaced by multiple ops. - // Will return the original op, which have all the Input/Outputs for the traversal. - // But with Writers and Readers cleared. - ifmConn->tensor->RemoveReader(operation->shared_from_this()); - } - return returnOp; -} - Operation *TFLiteGraphOptimiser::RemoveReshape(Graph *const graph, Operation *const operation) { diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.hpp b/ethosu/regor/compiler/tflite_graph_optimiser.hpp index ef453580..d87c32bd 100644 --- a/ethosu/regor/compiler/tflite_graph_optimiser.hpp +++ b/ethosu/regor/compiler/tflite_graph_optimiser.hpp @@ -73,19 +73,9 @@ private: // Get axis parameter for operator int GetAxis(const Operation *const operation); - // Calculate the read shape and offset values for StridedSlice. - void SetStridedSliceOffsetValues(Operation *const operation, const TensorConnection *const ifmConn, Shape &readShape, Shape &readOffset); // Creates MemoryCopy operation for the given ifm/ofm and write offset. std::shared_ptr MakeMemoryCopyForConcat( const TensorConnection *const ofmConn, const TensorConnection *const ifmConn, const Shape &writeOffset); - // Creates a MemoryCopy operation for the given ifm/ofm and readOffset. - std::shared_ptr MakeMemoryCopyForSplitOps(const TensorConnection *const ofmConn, - const TensorConnection *const ifmConn, const Shape &readShape, const Shape &readOffset); - - // Creates the desired Output shape of StridedSlice. - // - // returns the Desired shape. - Shape MakeStridedSliceDesiredShape(Operation *const operation, const Shape &baseShape); Operation *MakeDepthwiseMeanOp(const TensorConnection *ifmConn, const Shape &ifmShape4D, const Shape &readShape, const Shape &readOffset, const Shape &ofmShape4D, int w, int h, const std::string &name, std::shared_ptr &weightTensor, @@ -98,8 +88,8 @@ private: Operation *ConvertExpToLUT(Graph *const graph, Operation *const operation); Operation *RewritePack(Graph *const graph, Operation *const operation); Operation *RewriteUnpack(Graph *const graph, Operation *const operation); - Operation *RewriteSplit(Graph *const graph, Operation *const operation); Operation *RewriteSlice(Graph *const graph, Operation *const operation); + Operation *RewriteStridedSlice(Graph *const graph, Operation *const operation); Operation *RemoveReshape(Graph *const graph, Operation *const operation); Operation *ConvertReverse(Graph *const graph, Operation *const operation); Operation *ConvertGather(Graph *const graph, Operation *const operation); @@ -190,16 +180,11 @@ public: {}, { &TFLiteGraphOptimiser::RewriteSlice, + &TFLiteGraphOptimiser::RewriteStridedSlice, &TFLiteGraphOptimiser::RewritePack, &TFLiteGraphOptimiser::RewriteUnpack } }, - { - {}, - { - &TFLiteGraphOptimiser::RewriteSplit - } - }, { {}, { -- GitLab