From 6018c9da7791037f7950c2e90695854f90a4b63c Mon Sep 17 00:00:00 2001 From: Alexander Bengtsson Date: Thu, 27 Mar 2025 10:14:21 +0100 Subject: [PATCH] MLBEDSW-10565: Add supported-op checks for Transpose - Add Transpose to ConstraintConstParams - Add ConstraintTransposeDims: Constrains permutation-vector to max 8D - Add TfLiteSupportedOperatorsU55::ConstraintTranspose Change-Id: I81bf6c0ff8b86fc4f60c6504942c0319ef01f7a3 Signed-off-by: Alexander Bengtsson --- .../tflite/tflite_supported_operators.cpp | 19 +++ .../tflite/tflite_supported_operators.hpp | 1 + .../tflite/tflite_supported_operators_u55.cpp | 117 ++++++++++++++++++ .../tflite/tflite_supported_operators_u55.hpp | 1 + 4 files changed, 138 insertions(+) diff --git a/ethosu/regor/tflite/tflite_supported_operators.cpp b/ethosu/regor/tflite/tflite_supported_operators.cpp index 4ceb9779..9e047932 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators.cpp @@ -581,6 +581,7 @@ bool TfLiteSupportedOperators::ConstraintConstParams(const Operation *op) case OpType::Pad: case OpType::PadV2: case OpType::MirrorPad: + case OpType::Transpose: break; default: return true; @@ -747,6 +748,23 @@ bool TfLiteSupportedOperators::ConstraintPad(const Operation *op) return true; } +bool TfLiteSupportedOperators::ConstraintTransposeDims(const Operation *op) +{ + OpType opType = op->Type(); + if ( opType != OpType::Transpose ) + { + return true; + } + auto params = op->Input(TensorUsage::Params); + assert(params); + if ( params->shape.Depth() > 8 ) + { + Failure(op, "Unsupported transpose-shape", "tensor dimension must be <= 8"); + return false; + } + return true; +} + void TfLiteSupportedOperators::Failure(const Operation *op, const std::string &message, const std::string &constraint) { assert(op); @@ -804,6 +822,7 @@ TfLiteSupportedOperators::TfLiteSupportedOperators(IArchitectureConstraints *con &TfLiteSupportedOperators::ConstraintMean, &TfLiteSupportedOperators::ConstraintSoftmax, &TfLiteSupportedOperators::ConstraintPad, + &TfLiteSupportedOperators::ConstraintTransposeDims, }; } diff --git a/ethosu/regor/tflite/tflite_supported_operators.hpp b/ethosu/regor/tflite/tflite_supported_operators.hpp index 66584a19..24133cca 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.hpp +++ b/ethosu/regor/tflite/tflite_supported_operators.hpp @@ -72,6 +72,7 @@ private: bool ConstraintMean(const Operation *op); bool ConstraintSoftmax(const Operation *op); bool ConstraintPad(const Operation *op); + bool ConstraintTransposeDims(const Operation *op); }; // Factory for supported-ops checkers diff --git a/ethosu/regor/tflite/tflite_supported_operators_u55.cpp b/ethosu/regor/tflite/tflite_supported_operators_u55.cpp index e10c400d..36cce2b5 100644 --- a/ethosu/regor/tflite/tflite_supported_operators_u55.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators_u55.cpp @@ -22,6 +22,8 @@ #include "common/logging.hpp" #include "compiler/op_type.hpp" +#include "compiler/operation_util.hpp" +#include "compiler/shape_util.hpp" namespace regor { @@ -100,6 +102,7 @@ TfLiteSupportedOperatorsU55::TfLiteSupportedOperatorsU55(IArchitectureConstraint &TfLiteSupportedOperatorsU55::ConstraintKernelStride, &TfLiteSupportedOperatorsU55::ConstraintUnrolledKernelStride, &TfLiteSupportedOperatorsU55::ConstraintMatmul, + &TfLiteSupportedOperatorsU55::ConstraintTranspose, }; } @@ -165,6 +168,120 @@ bool TfLiteSupportedOperatorsU55::ConstraintReverse(const Operation *op) return true; } +bool TfLiteSupportedOperatorsU55::ConstraintTranspose(const Operation *op) +{ + OpType opType = op->Type(); + if ( opType != OpType::Transpose ) + { + return true; + } + auto ifmConn = op->Input(TensorUsage::IFM); + auto ofmConn = op->Output(TensorUsage::OFM); + auto ifmShape = Shape::PadAxes(ifmConn->shape, 4, 1); + auto ifmType = ifmConn->tensor->Type(); + auto *params = op->Input(TensorUsage::Params); + assert(params); + Shape perm = TensorToShape(params->tensor.get(), params->shape.Depth()); + auto transposeMask = TransposeTypeFromShape(perm); + if ( ifmType == DataType::Int32 ) + { + static const char *constraint = + "IFM Shape constraints for 32-bit Transpose:\n" + " * Rank must be less than or equal to 4\n" + " * Max shape based on permutation:\n" + " NHWC: C <= 2^16\n" + " NWHC: N ==1, H <= 2^16, W <= 2^16, C <= 2^14\n" + " NHCW: N*H <= 2^16, W <= 2^16, C <= 2^16\n" + " Any other permutation vector is unsupported"; + if ( ifmShape.Size() > 4 ) + { + Failure(op, fmt::format("32-bit transpose with rank > 4: {}", ifmShape.ToString()), constraint); + return false; + } + switch ( transposeMask ) + { + case TransposeType::None: + // 32-bit NHWC: C-axis must be 0->32768 + if ( ifmShape.Depth() > (1 << 15) ) + { + Failure(op, fmt::format("32-bit NHWC transpose with depth > 32768: {}", ifmShape.ToString()), constraint); + return false; + } + break; + case TransposeType::NWHC: + { + // 32-bit NWHC: max-shape (1,65536,65536,16384) + const static Shape maxShape = Shape(1, (1 << 16), (1 << 16), (1 << 14)); + if ( ifmShape.GreaterMask(maxShape) > 0 ) + { + Failure(op, fmt::format("32-bit NWHC transpose with shape out of range: {}", ifmShape.ToString()), constraint); + return false; + } + } + break; + case TransposeType::NHCW: + { + // 32-bit NHCW: (N*H: 65536, W: 65536, C: 65536) + const static Shape maxShape = Shape((1 << 16), (1 << 16), (1 << 16)); + Shape ifmSquashed = ifmShape.WithHeight(ifmShape.Height() * ifmShape.Batch()).WithBatch(1); + if ( ifmSquashed.GreaterMask(maxShape) > 0 ) + { + Failure(op, fmt::format("32-bit NHCW transpose with shape out of range: {}", ifmSquashed.ToString()), constraint); + return false; + } + } + break; + default: + Failure(op, "Unsupported transpose-type", constraint); + return false; + } + } + else + { + static const char *constraint = + "IFM shape constraints for 8 or 16-bit Transpose:\n" + " * Max shape based on permutation:\n" + " NHWC: no shape constraints\n" + " ELSE IF Rank <= 4D and permutation is: NWHC/NHCW/NCWH:\n" + " (N*H, W, C) <= (2^16, 2^16, 2^16)\n" + " ELSE:\n" + " Product of elements must be less than or equal to 2^16."; + if ( transposeMask == TransposeType::None ) + { + // NHWC: any size is supported + return true; + } + if ( (ifmShape.Size() <= 4) && + (transposeMask == TransposeType::NWHC || transposeMask == TransposeType::NHCW || transposeMask == TransposeType::NCWH) ) + { + // Directly HW-supported transpose-masks + // NWHC/NHCW/NCWH: (N*H: 65536, 65536, 65536) + const static Shape maxShape = Shape((1 << 16), (1 << 16), (1 << 16)); + Shape ifmSquashed = ifmShape.WithHeight(ifmShape.Height() * ifmShape.Batch()).WithBatch(1); + if ( ifmSquashed.GreaterMask(maxShape) > 0 ) + { + Failure(op, + fmt::format("Transpose with permutation {} has shape out of range: {}", EnumToString(transposeMask), + ifmSquashed.ToString()), + constraint); + return false; + } + } + else + { + // Decomposed transpose-masks + // Axis product must be less or equal to 65536 + if ( ifmShape.Elements64() > (1 << 16) ) + { + Failure(op, + fmt::format("Transpose with permutation {} has shape out of range: {}", perm.ToString(), ifmShape.ToString()), constraint); + return false; + } + } + } + return true; +} + bool TfLiteSupportedOperatorsU55::Constraint32bitOps(const Operation *op) { static const std::unordered_set supported = { diff --git a/ethosu/regor/tflite/tflite_supported_operators_u55.hpp b/ethosu/regor/tflite/tflite_supported_operators_u55.hpp index 2f5fa8ab..1daa1743 100644 --- a/ethosu/regor/tflite/tflite_supported_operators_u55.hpp +++ b/ethosu/regor/tflite/tflite_supported_operators_u55.hpp @@ -49,5 +49,6 @@ private: bool ConstraintArgMaxDepth(const Operation *op); bool ConstraintArgMaxAxis(const Operation *op); bool ConstraintArgMaxOverflow(const Operation *op); // TODO: Remove after MLBEDSW-9758: TOSA MaxPool decomp + bool ConstraintTranspose(const Operation *op); }; } // namespace regor -- GitLab