From 3618e3467b61dc2b3e6df1afbf9eb806f242b93f Mon Sep 17 00:00:00 2001 From: Jacob Bohlin Date: Thu, 22 May 2025 15:13:01 +0100 Subject: [PATCH] MLBEDSW-9535 MatMul decomposition Change-Id: I60855fe2aa57db32f80b9f057667574bb0d56d04 Signed-off-by: Jacob Bohlin --- ethosu/regor/compiler/scheduler_decompose.cpp | 77 ++++++++++++++++--- 1 file changed, 68 insertions(+), 9 deletions(-) diff --git a/ethosu/regor/compiler/scheduler_decompose.cpp b/ethosu/regor/compiler/scheduler_decompose.cpp index 44032b83..18bcc2e8 100644 --- a/ethosu/regor/compiler/scheduler_decompose.cpp +++ b/ethosu/regor/compiler/scheduler_decompose.cpp @@ -1488,25 +1488,84 @@ std::vector> DecomposeMatmul(Architecture *a auto ifmConn = op->Input(TensorUsage::IFM); auto &ifmShape = ifmConn->SliceShape(); auto &ifmSlice = ifmConn->slice; + auto ifm2Conn = op->Input(TensorUsage::IFM1); + auto &ifm2Shape = ifm2Conn->SliceShape(); + auto &ifm2Slice = ifm2Conn->slice; ofmSlice.Initialize(ofmShape.WithZeros(), ofmShape); ifmSlice.Initialize(ifmShape.WithZeros(), ifmShape); + ifm2Slice.Initialize(ifm2Shape.WithZeros(), ifm2Shape); - // TODO MLBEDSW-9535: large tensor decomposition - if ( auto ifm2Conn = op->TryInput(TensorUsage::IFM1) ) - { - auto &ifm2Shape = ifm2Conn->shape; - auto &ifm2Slice = ifm2Conn->slice; - - ifm2Slice.Initialize(ifm2Shape.WithZeros(), ifm2Shape); - } - + // Decompose Batching auto ofmRank = ofmShape.Size(); if ( ofmRank > 2 && (ofmShape.Elements() > ofmShape.Width() * ofmShape.Depth()) ) { return DecomposeLeadingDimensions(ofmRank - 2, arch, std::move(op), DecomposeMatmul); } + // Define total dimensions of input and output matrices + int OH = ofmShape.Width(); // Num rows in output + int OW = ofmShape.Depth(); // Num columns in output + int IC = ifmShape.Depth(); // Num elements in contracted axis + + if ( OH > MAX_DIM || OW > MAX_DIM || IC > MAX_DIM ) + { + // Define block dimensions + int BH = std::min(OH, MAX_DIM); // Num rows in output block + int BW = std::min(OW, MAX_DIM); // Num columns in output block + int BC = std::min(IC, MAX_DIM); // Num channels in input blocks + if ( IC > MAX_DIM && arch->Constraints()->SupportsAccumulatorSaveRestore() ) + { + // Splitting the op into sup-ops constrained by the block height/width ensures + // that the accumulators are not overwritten before the vector product completes along + // the contracted axis. + auto blockConfigHW = GetOpConfig(arch, op.get())->OptimalStripeGranule(); + BH = std::min(BH, blockConfigHW.y); + BW = std::min(BW, blockConfigHW.x); + } + + for ( int height = 0; height < OH; height += BH ) + { + for ( int width = 0; width < OW; width += BW ) + { + for ( int depth = 0; depth < IC; depth += BC ) + { + auto newIfmSlice = ifmSlice; + newIfmSlice.shape = Shape::PadAxes( + Shape::Min(Shape(OH, IC) - Shape(height, depth), Shape(BH, BC)), ifmShape.Size(), 1); + newIfmSlice.offset += Shape::PadAxes(Shape(height, depth), ifmShape.Size(), 0); + + auto newIfm2Slice = ifm2Slice; + newIfm2Slice.shape = Shape::PadAxes( + Shape::Min(Shape(OW, IC) - Shape(width, depth), Shape(BW, BC)), ifm2Shape.Size(), 1); + newIfm2Slice.offset += Shape::PadAxes(Shape(width, depth), ifm2Shape.Size(), 0); + + auto newOfmSlice = ofmSlice; + newOfmSlice.shape = Shape::PadAxes( + Shape::Min(Shape(OH, OW) - Shape(height, width), Shape(BH, BW)), ofmShape.Size(), 1); + newOfmSlice.offset += Shape::PadAxes(Shape(height, width), ofmShape.Size(), 0); + + std::unique_ptr subOp = MakeSubOperation(op.get()); + subOp->Input(TensorUsage::IFM)->slice = newIfmSlice; + subOp->Input(TensorUsage::IFM1)->slice = newIfm2Slice; + subOp->Output(TensorUsage::OFM)->slice = newOfmSlice; + + // Set accumulator mode according to these conditions: + // * Reset accumulators if first depth-wise block otherwise preserve + // * Enable output if last depth-wise block + AccumulatorControl accMode = { + depth == 0 ? AccumulatorSource::Reset : AccumulatorSource::Acc, + depth + BC >= IC ? true : false, + }; + subOp->SetAccumulatorMode(accMode); + result.emplace_back(std::move(subOp)); + } + } + } + + return result; + } + result.emplace_back(std::move(op)); return result; } -- GitLab