diff --git a/ethosu/regor/test/test_tflite_supported_operators.cpp b/ethosu/regor/test/test_tflite_supported_operators.cpp index 3493e3ff85a24ddb684ed5f5c5ce5583a2786452..286b2695419c4cce223a9f63a68651b7ea02f76b 100644 --- a/ethosu/regor/test/test_tflite_supported_operators.cpp +++ b/ethosu/regor/test/test_tflite_supported_operators.cpp @@ -278,6 +278,71 @@ TEST_CASE("Supported operators Common") } op->Disconnect(); } + + SECTION("ConstrainMaxPoolKernel") + { + auto op = CreateOperation(OpType::MaxPool, Shape(1, 1000, 1000, 1), DataType::Int8, Shape(1, 1000, 1000, 1), DataType::Int8); + + auto SetKernel = [&op](int h, int w) + { + auto kernel = std::make_unique(Point2i{w, h}, Point2i{1, 1}, Point2i{1, 1}, 1, Margin{0, 0, 0, 0}); + op->SetKernel(std::move(kernel)); + auto ofmConn = op->Output(TensorUsage::OFM); + auto ifmConn = op->Input(TensorUsage::IFM); + auto &ofmShape = ofmConn->shape; + auto &ifmShape = ofmConn->shape; + ofmShape = ifmShape.WithWidth(ifmShape.Width() - w).WithHeight(ifmShape.Height() - h); + }; + SetKernel(8, 8); + REQUIRE(supportedOps->Check(op.get()) == true); + SetKernel(256, 256); + REQUIRE(supportedOps->Check(op.get()) == true); + SetKernel(256, 257); + REQUIRE(supportedOps->Check(op.get()) == false); + SetKernel(257, 256); + REQUIRE(supportedOps->Check(op.get()) == false); + op->Disconnect(); + } + + SECTION("ConstrainAvgPoolKernel") + { + auto op = CreateOperation(OpType::AvgPool, Shape(1, 100, 100, 1), DataType::Int8, Shape(1, 100, 100, 1), DataType::Int8); + + auto SetKernel = [&op](int h, int w, int sh = 1, int sw = 1, int ph = 0, int pw = 0) + { + int t = ph / 2; + int b = ph - t; + int l = pw / 2; + int r = pw - l; + auto kernel = std::make_unique(Point2i{w, h}, Point2i{sw, sh}, Point2i{1, 1}, 1, Margin{t, b, l, r}); + op->SetKernel(std::move(kernel)); + auto ofmConn = op->Output(TensorUsage::OFM); + auto ifmConn = op->Input(TensorUsage::IFM); + auto &ofmShape = ofmConn->shape; + auto &ifmShape = ofmConn->shape; + ofmShape = ifmShape.WithWidth((ifmShape.Width() - w + pw) / sw).WithHeight((ifmShape.Height() - h + ph) / sh); + }; + // max size (VALID padding) + SetKernel(256, 256, 1, 1, 0, 0); + REQUIRE(supportedOps->Check(op.get()) == true); + // too large prod (VALID padding) + SetKernel(256, 257, 1, 1, 0, 0); + REQUIRE(supportedOps->Check(op.get()) == false); + // too large height (VALID padding) + SetKernel(257, 8, 1, 1, 0, 0); + REQUIRE(supportedOps->Check(op.get()) == false); + + // max size (SAME padding) + SetKernel(8, 8, 1, 1, 1, 1); + REQUIRE(supportedOps->Check(op.get()) == true); + // too large width (SAME padding) + SetKernel(8, 9, 1, 1, 1, 1); + REQUIRE(supportedOps->Check(op.get()) == false); + // OK if width matches stride + SetKernel(8, 9, 1, 9, 1, 1); + REQUIRE(supportedOps->Check(op.get()) == true); + op->Disconnect(); + } } TEST_CASE("Supported operators EthosU55") diff --git a/ethosu/regor/tflite/tflite_supported_operators.cpp b/ethosu/regor/tflite/tflite_supported_operators.cpp index 07a5184e9db88577fc199e84428be0c859697e8a..d7d1f3866f433ebff644a77668529a2b9cb792bf 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators.cpp @@ -392,6 +392,75 @@ bool TfLiteSupportedOperators::ConstraintBias(const Operation *op) return true; } +bool TfLiteSupportedOperators::ConstraintAvgPool(const Operation *op) +{ + OpType opType = op->Type(); + if ( opType != OpType::AvgPool ) + { + return true; + } + auto kernel = op->Kernel(); + assert(kernel); + auto [w, h] = kernel->Size(); + auto [sw, sh] = kernel->Stride(); + if ( kernel->Padding().IsZero() ) + { + // VALID padding + if ( h > 256 || h < 1 ) + { + Failure(op, fmt::format("kernel height: {} out of range", h), "When padding=VALID, kernel-height must be in the range (1,256)"); + return false; + } + if ( h * w > 256 * 256 ) + { + Failure(op, fmt::format("kernel product: {} out of range", h * w), + "When padding=VALID, kernel product (H*W) must be in the range (1, 256*256)"); + return false; + } + } + else + { + // SAME padding + if ( w != sw && (w > 8 || w < 1) ) + { + // kernel width out of range + Failure(op, fmt::format("kernel width: {} out of range", w), + "When padding=SAME, kernel width must be in the range (1,8) OR equal to the stride(width)"); + return false; + } + if ( h > 8 || h < 1 ) + { + Failure(op, fmt::format("kernel height: {} out of range", h), "When padding=SAME, kernel height must be in the range (1,8)"); + return false; + } + } + return true; +} + +bool TfLiteSupportedOperators::ConstraintMaxPool(const Operation *op) +{ + OpType opType = op->Type(); + if ( opType != OpType::MaxPool ) + { + return true; + } + auto kernel = op->Kernel(); + assert(kernel); + auto [w, h] = kernel->Size(); + auto [sw, sh] = kernel->Stride(); + if ( h > 256 || h < 1 ) + { + Failure(op, fmt::format("kernel height: {} out of range", h), "Kernel height must be in the range (1, 256)"); + return false; + } + if ( h * w > 256 * 256 ) + { + Failure(op, fmt::format("kernel product: {} out of range", h * w), "Kernel product must be in the range (1, 256 * 256)"); + return false; + } + return true; +} + void TfLiteSupportedOperators::Failure(const Operation *op, const std::string &message, const std::string &constraint) { assert(op); @@ -434,6 +503,8 @@ TfLiteSupportedOperators::TfLiteSupportedOperators(IArchitectureConstraints *con &TfLiteSupportedOperators::ConstraintWeightsPrecision, &TfLiteSupportedOperators::ConstraintWeightSum, &TfLiteSupportedOperators::ConstraintBias, + &TfLiteSupportedOperators::ConstraintAvgPool, + &TfLiteSupportedOperators::ConstraintMaxPool, }; } diff --git a/ethosu/regor/tflite/tflite_supported_operators.hpp b/ethosu/regor/tflite/tflite_supported_operators.hpp index 5b3ea01f66cfb27fe8bcc581e5030f66df44cf85..943203e67b4ba3dfc3dd693e20301af2e5996bc7 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.hpp +++ b/ethosu/regor/tflite/tflite_supported_operators.hpp @@ -63,5 +63,7 @@ private: bool ConstraintWeightsPrecision(const Operation *op); bool ConstraintWeightSum(const Operation *op); bool ConstraintBias(const Operation *op); + bool ConstraintAvgPool(const Operation *op); + bool ConstraintMaxPool(const Operation *op); }; } // namespace regor