From e56bd6e76f71ced2039e5f9451684f860107920b Mon Sep 17 00:00:00 2001 From: Alexander Bengtsson Date: Fri, 21 Mar 2025 14:13:26 +0100 Subject: [PATCH] MLBEDSW-10577: Add kernel-stride constraints for Ethos-U55 - ConstraintKernelStride Constrains stride to (1,3) for non Conv2D opTypes - ConstraintUnrolledKernelStride Constrains stride for Conv2D based on Unrolling conditions - Change kernel-size constraints for AvgPool to align with stride-constraints Change-Id: I715b26a95bbd6b172b2dcbea94a7ca61e4d17450 Signed-off-by: Alexander Bengtsson --- .../test/test_tflite_supported_operators.cpp | 31 ++++++++-- .../tflite/tflite_supported_operators.cpp | 6 +- .../tflite/tflite_supported_operators_u55.cpp | 61 +++++++++++++++++++ .../tflite/tflite_supported_operators_u55.hpp | 2 + 4 files changed, 92 insertions(+), 8 deletions(-) diff --git a/ethosu/regor/test/test_tflite_supported_operators.cpp b/ethosu/regor/test/test_tflite_supported_operators.cpp index 2970e346..ceece6ec 100644 --- a/ethosu/regor/test/test_tflite_supported_operators.cpp +++ b/ethosu/regor/test/test_tflite_supported_operators.cpp @@ -341,9 +341,6 @@ TEST_CASE("Supported operators Common") // too large width (SAME padding) SetKernel(8, 9, 1, 1, 1, 1); REQUIRE(supportedOps->Check(op.get()) == false); - // OK if width matches stride - SetKernel(8, 9, 1, 9, 1, 1); - REQUIRE(supportedOps->Check(op.get()) == true); op->Disconnect(); } @@ -554,7 +551,6 @@ TEST_CASE("Supported operators EthosU55") op->Disconnect(); } - SECTION("Constraint32BitOps") { auto op = CreateOperation(OpType::Add, Shape(1, 1, 1, 1), DataType::Int32, Shape(1, 1, 1, 1), DataType::Int32, @@ -565,6 +561,33 @@ TEST_CASE("Supported operators EthosU55") op->Disconnect(); op2->Disconnect(); } + + SECTION("ConstraintStride") + { + { + auto op = CreateOperation(OpType::MaxPool, Shape(1, 10, 10, 1), DataType::Int8, Shape(1, 10, 10, 1), DataType::Int8); + auto kernel = std::make_unique(Point2i{1, 1}, Point2i{1, 1}, Point2i{1, 1}, 1, Margin{0, 0, 0, 0}); + op->SetKernel(std::move(kernel)); + REQUIRE(supportedOps->Check(op.get()) == true); + op->Disconnect(); + } + { + // stride > 3 is not supported for MaxPool + auto op = CreateOperation(OpType::MaxPool, Shape(1, 10, 10, 1), DataType::Int8, Shape(1, 2, 2, 1), DataType::Int8); + auto kernel = std::make_unique(Point2i{1, 1}, Point2i{5, 5}, Point2i{1, 1}, 1, Margin{0, 0, 0, 0}); + op->SetKernel(std::move(kernel)); + REQUIRE(supportedOps->Check(op.get()) == false); + op->Disconnect(); + } + { + // stride > 3 is supported for Conv2D (it's unrolled) + auto op = CreateOperation(OpType::Conv2D, Shape(1, 10, 10, 1), DataType::Int8, Shape(1, 2, 2, 1), DataType::Int8); + auto kernel = std::make_unique(Point2i{1, 1}, Point2i{5, 5}, Point2i{1, 1}, 1, Margin{0, 0, 0, 0}); + op->SetKernel(std::move(kernel)); + REQUIRE(supportedOps->Check(op.get()) == true); + op->Disconnect(); + } + } } TEST_CASE("Supported operators EthosU85") diff --git a/ethosu/regor/tflite/tflite_supported_operators.cpp b/ethosu/regor/tflite/tflite_supported_operators.cpp index cda0dad1..2c1a779d 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators.cpp @@ -401,7 +401,6 @@ bool TfLiteSupportedOperators::ConstraintAvgPool(const Operation *op) auto kernel = op->Kernel(); assert(kernel); auto [w, h] = kernel->Size(); - auto [sw, sh] = kernel->Stride(); if ( kernel->Padding().IsZero() ) { // VALID padding @@ -420,11 +419,10 @@ bool TfLiteSupportedOperators::ConstraintAvgPool(const Operation *op) else { // SAME padding - if ( w != sw && (w > 8 || w < 1) ) + if ( w > 8 || w < 1 ) { // kernel width out of range - Failure(op, fmt::format("kernel width: {} out of range", w), - "When padding=SAME, kernel width must be in the range (1,8) OR equal to the stride(width)"); + Failure(op, fmt::format("kernel width: {} out of range", w), "When padding=SAME, kernel width must be in the range (1,8)"); return false; } if ( h > 8 || h < 1 ) diff --git a/ethosu/regor/tflite/tflite_supported_operators_u55.cpp b/ethosu/regor/tflite/tflite_supported_operators_u55.cpp index 5d49a09c..62f861c8 100644 --- a/ethosu/regor/tflite/tflite_supported_operators_u55.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators_u55.cpp @@ -94,6 +94,8 @@ TfLiteSupportedOperatorsU55::TfLiteSupportedOperatorsU55(IArchitectureConstraint &TfLiteSupportedOperatorsU55::ConstraintBroadcastShapes, &TfLiteSupportedOperatorsU55::ConstraintReverse, &TfLiteSupportedOperatorsU55::Constraint32bitOps, + &TfLiteSupportedOperatorsU55::ConstraintKernelStride, + &TfLiteSupportedOperatorsU55::ConstraintUnrolledKernelStride, }; } @@ -195,4 +197,63 @@ bool TfLiteSupportedOperatorsU55::Constraint32bitOps(const Operation *op) } return true; } + +bool TfLiteSupportedOperatorsU55::ConstraintKernelStride(const Operation *op) +{ + const auto kernel = op->Kernel(); + assert(kernel); + const int32_t stride_w = kernel->Stride().x; + const int32_t stride_h = kernel->Stride().y; + if ( op->Type() == OpType::Conv2D ) + { + // Conv2D is handled by ConstraintUnrolledKernelStride + return true; + } + if ( stride_w > 3 || stride_h > 3 ) + { + Failure(op, fmt::format("Unsupported kernel stride: {}, {}", stride_w, stride_h), "kernel stride must be in the range (1,3)"); + return false; + } + return true; +} + +bool TfLiteSupportedOperatorsU55::ConstraintUnrolledKernelStride(const Operation *op) +{ + // Constraints for UnrollConv + const static char *constraint = + "Stride >3 is only supported when:\n" + "\t * kernel dilation = 1\n" + "\t * IFM and OFM are not sliced\n" + "\t * padding = VALID\n"; + const auto ifmConn = op->Input(TensorUsage::IFM); + const auto ofmConn = op->Output(TensorUsage::OFM); + const auto kernel = op->Kernel(); + assert(ifmConn); + assert(ofmConn); + assert(kernel); + if ( op->Type() != OpType::Conv2D ) + { + return true; + } + const int32_t stride_w = kernel->Stride().x; + const int32_t stride_h = kernel->Stride().y; + if ( stride_w <= 3 && stride_h <= 3 ) + { + // always supported + return true; + } + // stride > 3 requires unrolling, check unroll conditions + const bool hasPadding = !kernel->Padding().IsZero(); + const bool hasIfmSlice = ifmConn->slice.shape.IsValid() || ifmConn->slice.offset.IsValid(); + const bool hasOfmSlice = ofmConn->slice.shape.IsValid() || ofmConn->slice.offset.IsValid(); + const int32_t dilation_h = kernel->Dilation().y; + const int32_t dilation_w = kernel->Dilation().x; + const bool canUnroll = !hasPadding && !hasIfmSlice && !hasOfmSlice && (dilation_h == 1) && (dilation_w == 1); + if ( !canUnroll ) + { + Failure(op, fmt::format("Unsupported kernel stride: {}, {}", stride_w, stride_h), constraint); + } + return true; +} + } // namespace regor diff --git a/ethosu/regor/tflite/tflite_supported_operators_u55.hpp b/ethosu/regor/tflite/tflite_supported_operators_u55.hpp index 74cfe59a..1d14f403 100644 --- a/ethosu/regor/tflite/tflite_supported_operators_u55.hpp +++ b/ethosu/regor/tflite/tflite_supported_operators_u55.hpp @@ -43,5 +43,7 @@ private: bool ConstraintBroadcastShapes(const Operation *op); bool ConstraintReverse(const Operation *op); bool Constraint32bitOps(const Operation *op); + bool ConstraintKernelStride(const Operation *op); + bool ConstraintUnrolledKernelStride(const Operation *op); }; } // namespace regor -- GitLab