diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp index 79a49c2dab60cdb8102360412f8cb238fb442552..1e83461a71aceddf066056be930b592959121794 100644 --- a/ethosu/regor/compiler/graphir_optimiser.cpp +++ b/ethosu/regor/compiler/graphir_optimiser.cpp @@ -1065,18 +1065,24 @@ Operation *GraphIrOptimiser::RewritePad(Graph *const, Operation *const operation return returnOp; } -Operation *GraphIrOptimiser::UnrollConv(Graph *const, Operation *const operation) +Operation *GraphIrOptimiser::UnrollKernelStrides(Graph *const, Operation *const operation) { auto returnOp = operation; - if ( operation->Type() == OpType::Conv2D ) + if ( operation->Type() == OpType::Conv2D || operation->Type() == OpType::AvgPool || operation->Type() == OpType::MaxPool ) { const auto ifmConn = operation->Input(TensorUsage::IFM); assert(ifmConn); - const auto weightsConn = operation->Input(TensorUsage::Weights); - assert(weightsConn); - const auto scalesConn = operation->Input(TensorUsage::Scales); - assert(scalesConn); + TensorConnection *weightsConn = nullptr; + TensorConnection *scalesConn = nullptr; + + if ( operation->Type() == OpType::Conv2D ) + { + weightsConn = operation->Input(TensorUsage::Weights); + assert(weightsConn); + scalesConn = operation->Input(TensorUsage::Scales); + assert(scalesConn); + } const auto ofmConn = operation->Output(TensorUsage::OFM); assert(ofmConn); @@ -1098,23 +1104,12 @@ Operation *GraphIrOptimiser::UnrollConv(Graph *const, Operation *const operation const bool hasIfmSlice = ifmConn->slice.shape.IsValid() || ifmConn->slice.offset.IsValid(); const bool hasOfmSlice = ofmConn->slice.shape.IsValid() || ofmConn->slice.offset.IsValid(); - tflite::Padding paddingType = tflite::Padding::VALID; - const tflite::Operator *const passthrough = static_cast<const tflite::Operator *>(operation->Passthrough()); - if ( passthrough ) - { - const auto options = passthrough->builtin_options_as_Conv2DOptions(); - if ( options ) - { - paddingType = 
options->padding(); - } - } - // Figure out if op needs to be unrolled const bool needUnrollH = stride_h > 3; const bool needUnrollW = stride_w > 3; // Figure out if op can be unrolled - const bool canUnroll = !hasPadding && !hasIfmSlice && !hasOfmSlice && paddingType == tflite::Padding::VALID; + const bool canUnroll = !hasPadding && !hasIfmSlice && !hasOfmSlice && kernel->Padding().IsZero(); const bool canUnrollH = dilation_h == 1 && canUnroll; const bool canUnrollW = dilation_w == 1 && canUnroll; @@ -1141,8 +1136,14 @@ Operation *GraphIrOptimiser::UnrollConv(Graph *const, Operation *const operation op->SetKernel(std::make_unique<Kernel>(kernel->WithStride({1, 1}))); op->CopyInput(TensorUsage::IFM, *ifmConn); op->Input(TensorUsage::IFM)->Set(ifmSlice); - op->CopyInput(TensorUsage::Weights, *weightsConn); - op->CopyInput(TensorUsage::Scales, *scalesConn); + if ( weightsConn ) + { + op->CopyInput(TensorUsage::Weights, *weightsConn); + } + if ( scalesConn ) + { + op->CopyInput(TensorUsage::Scales, *scalesConn); + } op->CopyOutput(TensorUsage::OFM, *ofmConn); op->Output(TensorUsage::OFM)->Set(ofmSlice); RecordOptimisation(*operation, op.get()); diff --git a/ethosu/regor/compiler/graphir_optimiser.hpp b/ethosu/regor/compiler/graphir_optimiser.hpp index 13691e69476deb4ce84ad156a9e96f052ecb5b06..e36c308e287143a8b5c7aaefade2a3846ee7b39c 100644 --- a/ethosu/regor/compiler/graphir_optimiser.hpp +++ b/ethosu/regor/compiler/graphir_optimiser.hpp @@ -76,7 +76,7 @@ private: Operation *ReshapeReverse(Graph *const graph, Operation *const operation); void MoveToConsumer(const Operation *const operation, Operation *const cons); Operation *MoveSplitSliceToConsumer(Graph *const, Operation *const operation); - Operation *UnrollConv(Graph *const, Operation *const operation); + Operation *UnrollKernelStrides(Graph *const, Operation *const operation); // Utility/Helper methods Operation *MakeFillOperation(TensorConnection *const ofmConn, const Shape &ofmShape, const TensorSlice &ofmSlice, 
std::shared_ptr<Tensor> padTensor); @@ -166,7 +166,7 @@ private: &GraphIrOptimiser::OptimiseElementwise, &GraphIrOptimiser::RearrangeTranspose, &GraphIrOptimiser::ReshapeReverse, - &GraphIrOptimiser::UnrollConv + &GraphIrOptimiser::UnrollKernelStrides } }, // MoveSplitSliceToConsumer need to be done after any other optimisation that can affect the ifm/ofm shapes diff --git a/ethosu/regor/test/test_tflite_supported_operators.cpp b/ethosu/regor/test/test_tflite_supported_operators.cpp index 05aaf67ef2768833f52674bac2f53eb361504de5..d234b51d62a1d2f0aa73acedc506e6e649006881 100644 --- a/ethosu/regor/test/test_tflite_supported_operators.cpp +++ b/ethosu/regor/test/test_tflite_supported_operators.cpp @@ -591,8 +591,8 @@ TEST_CASE("Supported operators EthosU55") op->Disconnect(); } { - // stride > 3 is not supported for MaxPool - auto op = CreateOperation(OpType::MaxPool, Shape(1, 10, 10, 1), DataType::Int8, Shape(1, 2, 2, 1), DataType::Int8); + // stride > 3 is not supported for Add + auto op = CreateOperation(OpType::Add, Shape(1, 10, 10, 1), DataType::Int8, Shape(1, 10, 10, 1), DataType::Int8); auto kernel = std::make_unique<Kernel>(Point2i{1, 1}, Point2i{5, 5}, Point2i{1, 1}, 1, Margin{0, 0, 0, 0}); op->SetKernel(std::move(kernel)); REQUIRE(supportedOps->Check(op.get()) == false); diff --git a/ethosu/regor/tflite/tflite_supported_operators_u55.cpp b/ethosu/regor/tflite/tflite_supported_operators_u55.cpp index a82f1dd752f09a23b76ade23240b84fa7b0934c3..eda556d8cd4b65f154c0402b54aab0e2b8ce4599 100644 --- a/ethosu/regor/tflite/tflite_supported_operators_u55.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators_u55.cpp @@ -469,9 +469,9 @@ bool TfLiteSupportedOperatorsU55::ConstraintKernelStride(const Operation *op) assert(kernel); const int32_t stride_w = kernel->Stride().x; const int32_t stride_h = kernel->Stride().y; - if ( op->Type() == OpType::Conv2D ) + if ( op->Type() == OpType::Conv2D || op->Type() == OpType::AvgPool || op->Type() == OpType::MaxPool ) { - // Conv2D is handled by 
ConstraintUnrolledKernelStride + // Conv2D and Pooling are handled by ConstraintUnrolledKernelStride return true; } if ( stride_w > 3 || stride_h > 3 ) @@ -484,7 +484,7 @@ bool TfLiteSupportedOperatorsU55::ConstraintKernelStride(const Operation *op) bool TfLiteSupportedOperatorsU55::ConstraintUnrolledKernelStride(const Operation *op) { - // Constraints for UnrollConv + // Constraints for UnrollKernelStrides const static char *constraint = "Stride >3 is only supported when:\n" "\t * kernel dilation = 1\n" @@ -496,7 +496,7 @@ bool TfLiteSupportedOperatorsU55::ConstraintUnrolledKernelStride(const Operation assert(ifmConn); assert(ofmConn); assert(kernel); - if ( op->Type() != OpType::Conv2D ) + if ( !(op->Type() == OpType::Conv2D || op->Type() == OpType::AvgPool || op->Type() == OpType::MaxPool) ) { return true; } @@ -517,6 +517,7 @@ bool TfLiteSupportedOperatorsU55::ConstraintUnrolledKernelStride(const Operation if ( !canUnroll ) { Failure(op, fmt::format("Unsupported kernel stride: {}, {}", stride_w, stride_h), constraint); + return false; } return true; }