diff --git a/ethosu/regor/test/test_tflite_supported_operators.cpp b/ethosu/regor/test/test_tflite_supported_operators.cpp index 2970e34606af9576afa03799483d2690dd65e32b..ceece6ec322630f56b614e5500417fbb8a2997c5 100644 --- a/ethosu/regor/test/test_tflite_supported_operators.cpp +++ b/ethosu/regor/test/test_tflite_supported_operators.cpp @@ -341,9 +341,6 @@ TEST_CASE("Supported operators Common") // too large width (SAME padding) SetKernel(8, 9, 1, 1, 1, 1); REQUIRE(supportedOps->Check(op.get()) == false); - // OK if width matches stride - SetKernel(8, 9, 1, 9, 1, 1); - REQUIRE(supportedOps->Check(op.get()) == true); op->Disconnect(); } @@ -554,7 +551,6 @@ TEST_CASE("Supported operators EthosU55") op->Disconnect(); } - SECTION("Constraint32BitOps") { auto op = CreateOperation(OpType::Add, Shape(1, 1, 1, 1), DataType::Int32, Shape(1, 1, 1, 1), DataType::Int32, @@ -565,6 +561,33 @@ TEST_CASE("Supported operators EthosU55") op->Disconnect(); op2->Disconnect(); } + + SECTION("ConstraintStride") + { + { + auto op = CreateOperation(OpType::MaxPool, Shape(1, 10, 10, 1), DataType::Int8, Shape(1, 10, 10, 1), DataType::Int8); + auto kernel = std::make_unique(Point2i{1, 1}, Point2i{1, 1}, Point2i{1, 1}, 1, Margin{0, 0, 0, 0}); + op->SetKernel(std::move(kernel)); + REQUIRE(supportedOps->Check(op.get()) == true); + op->Disconnect(); + } + { + // stride > 3 is not supported for MaxPool + auto op = CreateOperation(OpType::MaxPool, Shape(1, 10, 10, 1), DataType::Int8, Shape(1, 2, 2, 1), DataType::Int8); + auto kernel = std::make_unique(Point2i{1, 1}, Point2i{5, 5}, Point2i{1, 1}, 1, Margin{0, 0, 0, 0}); + op->SetKernel(std::move(kernel)); + REQUIRE(supportedOps->Check(op.get()) == false); + op->Disconnect(); + } + { + // stride > 3 is supported for Conv2D (it's unrolled) + auto op = CreateOperation(OpType::Conv2D, Shape(1, 10, 10, 1), DataType::Int8, Shape(1, 2, 2, 1), DataType::Int8); + auto kernel = std::make_unique(Point2i{1, 1}, Point2i{5, 5}, Point2i{1, 1}, 1, Margin{0, 0, 0, 0}); + op->SetKernel(std::move(kernel)); + REQUIRE(supportedOps->Check(op.get()) == true); + op->Disconnect(); + } + } } TEST_CASE("Supported operators EthosU85") diff --git a/ethosu/regor/tflite/tflite_supported_operators.cpp b/ethosu/regor/tflite/tflite_supported_operators.cpp index cda0dad13d8e2ede26cc089ddc224c67486c8472..2c1a779d6cbef4b0cde7789279537919b51f50af 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators.cpp @@ -401,7 +401,6 @@ bool TfLiteSupportedOperators::ConstraintAvgPool(const Operation *op) auto kernel = op->Kernel(); assert(kernel); auto [w, h] = kernel->Size(); - auto [sw, sh] = kernel->Stride(); if ( kernel->Padding().IsZero() ) { // VALID padding @@ -420,11 +419,10 @@ bool TfLiteSupportedOperators::ConstraintAvgPool(const Operation *op) else { // SAME padding - if ( w != sw && (w > 8 || w < 1) ) + if ( w > 8 || w < 1 ) { // kernel width out of range - Failure(op, fmt::format("kernel width: {} out of range", w), - "When padding=SAME, kernel width must be in the range (1,8) OR equal to the stride(width)"); + Failure(op, fmt::format("kernel width: {} out of range", w), "When padding=SAME, kernel width must be in the range (1,8)"); return false; } if ( h > 8 || h < 1 ) diff --git a/ethosu/regor/tflite/tflite_supported_operators_u55.cpp b/ethosu/regor/tflite/tflite_supported_operators_u55.cpp index 5d49a09c41e810611ddcd170c8359e205495d2cd..62f861c810244096fcc5523889b5a932102c0929 100644 --- a/ethosu/regor/tflite/tflite_supported_operators_u55.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators_u55.cpp @@ -94,6 +94,8 @@ TfLiteSupportedOperatorsU55::TfLiteSupportedOperatorsU55(IArchitectureConstraint &TfLiteSupportedOperatorsU55::ConstraintBroadcastShapes, &TfLiteSupportedOperatorsU55::ConstraintReverse, &TfLiteSupportedOperatorsU55::Constraint32bitOps, + &TfLiteSupportedOperatorsU55::ConstraintKernelStride, + &TfLiteSupportedOperatorsU55::ConstraintUnrolledKernelStride, }; } @@ -195,4 +197,63 @@ bool TfLiteSupportedOperatorsU55::Constraint32bitOps(const Operation *op) } return true; } + +bool TfLiteSupportedOperatorsU55::ConstraintKernelStride(const Operation *op) +{ + const auto kernel = op->Kernel(); + assert(kernel); + const int32_t stride_w = kernel->Stride().x; + const int32_t stride_h = kernel->Stride().y; + if ( op->Type() == OpType::Conv2D ) + { + // Conv2D is handled by ConstraintUnrolledKernelStride + return true; + } + if ( stride_w > 3 || stride_h > 3 ) + { + Failure(op, fmt::format("Unsupported kernel stride: {}, {}", stride_w, stride_h), "kernel stride must be in the range (1,3)"); + return false; + } + return true; +} + +bool TfLiteSupportedOperatorsU55::ConstraintUnrolledKernelStride(const Operation *op) +{ + // Constraints for UnrollConv + const static char *constraint = + "Stride >3 is only supported when:\n" + "\t * kernel dilation = 1\n" + "\t * IFM and OFM are not sliced\n" + "\t * padding = VALID\n"; + const auto ifmConn = op->Input(TensorUsage::IFM); + const auto ofmConn = op->Output(TensorUsage::OFM); + const auto kernel = op->Kernel(); + assert(ifmConn); + assert(ofmConn); + assert(kernel); + if ( op->Type() != OpType::Conv2D ) + { + return true; + } + const int32_t stride_w = kernel->Stride().x; + const int32_t stride_h = kernel->Stride().y; + if ( stride_w <= 3 && stride_h <= 3 ) + { + // always supported + return true; + } + // stride > 3 requires unrolling, check unroll conditions + const bool hasPadding = !kernel->Padding().IsZero(); + const bool hasIfmSlice = ifmConn->slice.shape.IsValid() || ifmConn->slice.offset.IsValid(); + const bool hasOfmSlice = ofmConn->slice.shape.IsValid() || ofmConn->slice.offset.IsValid(); + const int32_t dilation_h = kernel->Dilation().y; + const int32_t dilation_w = kernel->Dilation().x; + const bool canUnroll = !hasPadding && !hasIfmSlice && !hasOfmSlice && (dilation_h == 1) && (dilation_w == 1); + if ( !canUnroll ) + { + Failure(op, fmt::format("Unsupported kernel stride: {}, {}", stride_w, stride_h), constraint); + } + return true; +} + } // namespace regor diff --git a/ethosu/regor/tflite/tflite_supported_operators_u55.hpp b/ethosu/regor/tflite/tflite_supported_operators_u55.hpp index 74cfe59ac2c895763c6f1f198cb18b96fd3ec0d5..1d14f40370e2a56f76852662ac37db3abaf12285 100644 --- a/ethosu/regor/tflite/tflite_supported_operators_u55.hpp +++ b/ethosu/regor/tflite/tflite_supported_operators_u55.hpp @@ -43,5 +43,7 @@ private: bool ConstraintBroadcastShapes(const Operation *op); bool ConstraintReverse(const Operation *op); bool Constraint32bitOps(const Operation *op); + bool ConstraintKernelStride(const Operation *op); + bool ConstraintUnrolledKernelStride(const Operation *op); }; } // namespace regor