diff --git a/ethosu/regor/compiler/scheduler_decompose.cpp b/ethosu/regor/compiler/scheduler_decompose.cpp index defa6679c8627c8cacbc02464d05288eb2929ca2..60c4023accd785a2cd38a4aed2807a9773ba199b 100644 --- a/ethosu/regor/compiler/scheduler_decompose.cpp +++ b/ethosu/regor/compiler/scheduler_decompose.cpp @@ -219,7 +219,7 @@ bool CanRunOnHardware(Architecture *arch, const SchedulerOperation *schedOp) return false; } } - if ( IsConvolution(schedOp->Type()) ) + if ( IsConvolution(schedOp->Type()) || IsPooling(schedOp->Type()) ) { auto &ofmShape = schedOp->OFM()->SliceShape(); if ( ofmShape.Size() > 3 && ofmShape.Batch() > 1 ) return false; @@ -279,6 +279,7 @@ bool CanDecompose(Architecture *, const SchedulerOperation *schedOp) if ( schedOp->Type() == OpType::ArgMax ) return true; if ( schedOp->Type() == OpType::Reverse ) return true; if ( schedOp->Type() == OpType::Transpose ) return true; + if ( schedOp->Type() == OpType::MaxPool ) return true; /* NOTE(review): CanRunOnHardware now returns false for any IsPooling() type whose OFM has rank > 3 and Batch() > 1, but only MaxPool is added to CanDecompose in this hunk - confirm the other pooling types are covered elsewhere. */ return false; } @@ -1696,4 +1697,36 @@ std::vector> DecomposeTranspose(Architecture return result; } +std::vector> DecomposeMaxPool(Architecture *arch, std::unique_ptr op) /* Decomposes a MaxPool scheduler operation. Resets the OFM/IFM slices, then either hands the op to DecomposeLeadingDimensions (with this function passed as the callback) or returns a vector containing just the op itself. Takes ownership of op. */ +{ + std::vector> result; + auto ofmConn = op->Output(TensorUsage::OFM); + auto &ofmShape = ofmConn->SliceShape(); + auto &ofmSlice = ofmConn->slice; + auto ifmConn = op->Input(TensorUsage::IFM); + auto &ifmShape = ifmConn->SliceShape(); + auto &ifmSlice = ifmConn->slice; + + ofmSlice.Initialize(ofmShape.WithZeros(), ofmShape); /* Re-initialise the OFM slice from its own slice shape; presumably a zero offset spanning the whole shape - confirm against the slice type's Initialize(). */ + ifmSlice.Initialize(ifmShape.WithZeros(), ifmShape); /* Same treatment for the IFM slice. */ + + if ( auto ifm2Conn = op->TryInput(TensorUsage::IFM1) ) /* Optional second input: note it uses the connection's shape member rather than SliceShape(). NOTE(review): MaxPool presumably has a single input, so this branch may never be taken - confirm. */ + { + auto &ifm2Shape = ifm2Conn->shape; + auto &ifm2Slice = ifm2Conn->slice; + + ifm2Slice.Initialize(ifm2Shape.WithZeros(), ifm2Shape); + } + + auto ofmRank = ofmShape.Size(); + if ( ofmRank > 3 && (ofmShape.Elements() > ofmShape.Height() * ofmShape.Width() * ofmShape.Depth()) ) /* More than three OFM dimensions and the total element count exceeds the Height*Width*Depth product, i.e. the dimensions ahead of the last three presumably hold more than one element. */ + { + return DecomposeLeadingDimensions(ofmRank - 3, arch, std::move(op), DecomposeMaxPool); /* Passes the count of dimensions beyond three and this function as the callback. */ + } + + result.emplace_back(std::move(op)); /* Otherwise the op is returned unchanged as the only entry. */ + 
return result; +} + + } // namespace regor diff --git a/ethosu/regor/compiler/scheduler_decompose.hpp b/ethosu/regor/compiler/scheduler_decompose.hpp index 76e4c99ca76e9d822dd91eee7f79e81ab4312e32..20c392a70ea53dbf6a67e32331dfdc7461de48ce 100644 --- a/ethosu/regor/compiler/scheduler_decompose.hpp +++ b/ethosu/regor/compiler/scheduler_decompose.hpp @@ -44,6 +44,7 @@ std::vector> DecomposeMatmul(Architecture *a std::vector> DecomposeReduce(Architecture *arch, std::unique_ptr op); std::vector> DecomposeReverse(Architecture *arch, std::unique_ptr op); std::vector> DecomposeTranspose(Architecture *arch, std::unique_ptr op); +std::vector> DecomposeMaxPool(Architecture *arch, std::unique_ptr op); // Operator query helpers diff --git a/ethosu/regor/compiler/scheduler_packing.cpp b/ethosu/regor/compiler/scheduler_packing.cpp index 53aa06c0d0e91613d836752c277f9ed773a91ab5..3b56b1a3cf5c572b53d52c2ed0177dbf68f685c4 100644 --- a/ethosu/regor/compiler/scheduler_packing.cpp +++ b/ethosu/regor/compiler/scheduler_packing.cpp @@ -589,6 +589,9 @@ std::vector> SchedulerPacking::DecomposeSche case OpType::Transpose: result = DecomposeTranspose(_arch, std::move(op)); break; + case OpType::MaxPool: + result = DecomposeMaxPool(_arch, std::move(op)); + break; default: if ( DecomposeAsElementwise(op->Type()) || op->Type() == OpType::MemoryCopy ) {