From 3fe41dd0bb917da3d919dcaa0af0e4b2ecc5b6f0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johan=20Alfv=C3=A9n?=
Date: Wed, 12 Feb 2025 16:32:15 +0100
Subject: [PATCH] MLBEDSW-10424: Fallback to Conv2D for ReduceSum when Rescale
 is unavailable
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Added the former fallback mechanism that replaces ReduceSum with a 1x1
  Conv2D operation when fused Rescale is not available

Change-Id: I4b2a26a80fc2e6de3262217c29094532306417a5
Signed-off-by: Johan Alfvén
---
 ethosu/regor/compiler/graphir_optimiser.cpp | 90 ++++++++++++++++-----
 1 file changed, 69 insertions(+), 21 deletions(-)

diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp
index 436c20d6..4442d402 100644
--- a/ethosu/regor/compiler/graphir_optimiser.cpp
+++ b/ethosu/regor/compiler/graphir_optimiser.cpp
@@ -1406,27 +1406,75 @@ Operation *GraphIrOptimiser::RewriteReduceSum(Graph *const graph, Operation *con
     const int64_t zp = ifmConn->quantization.zeroPoints.empty() ? 0 : ifmConn->quantization.zeroPoints[0];
     if ( zp != 0 )
     {
-        // Replace ReduceSum (zp != 0) with ReduceSum->Sub(zp):
-
-        // Temporary tensor between ReduceSum and Sub
-        std::shared_ptr<Tensor> reduceSumTens = ofmConn->tensor->Clone();
-        reduceSumTens->SetName(ofmConn->tensor->Name() + "_reducesum");
-        reduceSumTens->ChangeType(DataType::Int32);
-        reduceSumTens->SetStorageShape(ofmConn->shape);
-
-        // Sub op with zero point
-        auto zpTens = CreateConstTensor("zero_point", DataType::Int32, int(ifmConn->shape.Depth() * zp));
-        auto subOp = std::make_shared<Operation>(OpType::Sub);
-        subOp->ConnectInput(TensorUsage::IFM, reduceSumTens);
-        subOp->ConnectInput(TensorUsage::IFM1, zpTens);
-        subOp->CopyOutput(TensorUsage::OFM, *ofmConn);
-        subOp->Output(TensorUsage::OFM)->Set(ofmConn->rounding);
-        RecordOptimisation(operation, subOp.get());
-        returnOp = subOp.get();
-
-        // Connect temporary tensor to reduceSum and remove the zero point
-        operation->ConnectOutput(TensorUsage::OFM, reduceSumTens).Set(Quantization::Unit());
-        ifmConn->quantization.zeroPoints[0] = 0;
+        if ( _constraints->SupportsFusedRescale(OpType::Sub, TensorUsage::OFM, DataType::Int32, DataType::Int32,
+                 DataType::Int32, DataType::Int32, ofmConn->quantization) )
+        {
+            // Replace ReduceSum (zp != 0) with ReduceSum->Sub(zp):
+
+            // Temporary tensor between ReduceSum and Sub
+            std::shared_ptr<Tensor> reduceSumTens = ofmConn->tensor->Clone();
+            reduceSumTens->SetName(ofmConn->tensor->Name() + "_reducesum");
+            reduceSumTens->ChangeType(DataType::Int32);
+            reduceSumTens->SetStorageShape(ofmConn->shape);
+
+            // Sub op with zero point
+            auto zpTens = CreateConstTensor("zero_point", DataType::Int32, int(ifmConn->shape.Depth() * zp));
+            auto subOp = std::make_shared<Operation>(OpType::Sub);
+            subOp->ConnectInput(TensorUsage::IFM, reduceSumTens);
+            subOp->ConnectInput(TensorUsage::IFM1, zpTens);
+            subOp->CopyOutput(TensorUsage::OFM, *ofmConn);
+            subOp->Output(TensorUsage::OFM)->Set(ofmConn->rounding);
+            RecordOptimisation(operation, subOp.get());
+            returnOp = subOp.get();
+
+            // Connect temporary tensor to reduceSum and remove the zero point
+            operation->ConnectOutput(TensorUsage::OFM, reduceSumTens).Set(Quantization::Unit());
+            ifmConn->quantization.zeroPoints[0] = 0;
+        }
+        else
+        {
+            // Replace ReduceSum (zp != 0) with 1x1 Conv2D:
+            //
+            // 1. Reshape to 3D shape (HWC) where C dimension is the dimension to reduce.
+            // 2. 1x1 Conv2D (1x1x1xC weights): HxWxC -> HxWx1.
+
+            // Reshape to 4D shape (NHWC) where C dimension is the dimension to reduce
+            const Shape ifmShape3D = ReshapeTo3D(Shape::PadAxes(ifmConn->shape, 3, 1), {ifmConn->shape.Size() - 2, 1, 1});
+            const Shape ifmShape4D = Shape::PadAxes(ifmShape3D, 4, 1);
+
+            // Create an identity 1x1x1xC weights tensor
+            auto weightsBuffer = std::make_shared<Buffer>(std::vector<int8_t>(ifmShape4D.Depth(), 1));
+            auto weightsTens = CreateConstTensor("weights", DataType::Int8, weightsBuffer);
+            weightsTens->SetStorageShape({1, 1, 1, ifmShape4D.Depth()});
+            weightsTens->SetAxisOrder(AxisOrder::OHWI);
+            auto weightsQuant = ifmConn->quantization;
+            weightsQuant.quantMin = {IntegerMin(DataType::Int8)};
+            weightsQuant.quantMax = {IntegerMax(DataType::Int8)};
+            weightsQuant.zeroPoints = {0};
+            weightsQuant.scales = {{1, 0}};  // Identity
+
+            // Create an identity bias tensor
+            auto biasTens = CreateConstTensor("bias", DataType::Int32, 0);
+            auto biasQuant = ifmConn->quantization;
+            biasQuant.zeroPoints = {0};
+
+            // Replace ReduceSum with a 1x1 Conv2D
+            Kernel kernel({1, 1}, {1, 1}, {1, 1});
+            auto convOp = std::make_shared<Operation>(OpType::Conv2D);
+            convOp->SetKernel(std::make_unique<Kernel>(kernel));
+            convOp->CopyInput(TensorUsage::IFM, *ifmConn);
+            convOp->Input(TensorUsage::IFM)->Set(ifmShape4D).Set(ifmConn->rounding);
+
+            convOp->ConnectInput(TensorUsage::Weights, weightsTens).Set(weightsQuant);
+            convOp->ConnectInput(TensorUsage::Scales, biasTens).Set(biasQuant);
+            convOp->CopyOutput(TensorUsage::OFM, *ofmConn);
+            convOp->Output(TensorUsage::OFM)->Set(ifmShape4D.WithDepth(1)).Set(ofmConn->rounding);
+            RecordOptimisation(operation, convOp.get());
+            returnOp = convOp.get();
+
+            // Remove old ReduceSum op
+            operation->Disconnect();
+        }
     }
     else if ( ifmConn->shape.Size() > 3 )
     {
--
GitLab
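
Note (not part of the patch): the fallback relies on the fact that, for each output
pixel, "ReduceSum over the channel axis, then subtract C * zero_point" equals a
quantized 1x1 convolution with all-ones int8 weights and zero bias, because the
convolution subtracts the IFM zero point from every element before accumulating.
The standalone C++ sketch below illustrates that arithmetic only; all names in it
are illustrative and unrelated to the regor API.

    // Standalone sketch: per-pixel equivalence between the ReduceSum->Sub(zp)
    // rewrite and the 1x1 Conv2D fallback with unit weights and zero bias.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Original rewrite: ReduceSum over C followed by Sub(C * zp).
    static int32_t ReduceSumThenSubZp(const std::vector<int8_t> &pixel, int32_t zp)
    {
        int32_t acc = 0;
        for ( int8_t v : pixel ) acc += v;
        return acc - int32_t(pixel.size()) * zp;
    }

    // Fallback rewrite: quantized 1x1 convolution; the IFM zero point is
    // subtracted from each element, the weight is 1 with zero point 0.
    static int32_t Conv1x1UnitWeights(const std::vector<int8_t> &pixel, int32_t zp)
    {
        int32_t acc = 0;  // bias == 0
        for ( int8_t v : pixel ) acc += (int32_t(v) - zp) * 1;
        return acc;
    }

    int main()
    {
        const std::vector<int8_t> pixel = {12, -3, 40, 7};  // one HxW position, C = 4
        const int32_t zp = 5;                               // non-zero IFM zero point
        assert(ReduceSumThenSubZp(pixel, zp) == Conv1x1UnitWeights(pixel, zp));
        return 0;
    }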