From ef43c810c765fe8ab6ea04f11998e879969d1d34 Mon Sep 17 00:00:00 2001 From: Alexander Bengtsson Date: Mon, 10 Feb 2025 13:58:23 +0100 Subject: [PATCH] MLBEDSW-10333: Fixes for RewriteReduceSum - Maintain rounding when lowering ReduceSum. - Fix shape of intermediate tensor when ReduceSum is lowered to ReduceSum + Sub (shape should be inherited from the tensor connection) - Keep quantization on the ReduceSum when ReduceSum is lowered to Transpose + ReduceSum. Change-Id: I9653f3638bbda8fa287a3f8a32c4bd3abe4e79ae Signed-off-by: Alexander Bengtsson --- ethosu/regor/compiler/graphir_optimiser.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp index 1ac16c86..565535ee 100644 --- a/ethosu/regor/compiler/graphir_optimiser.cpp +++ b/ethosu/regor/compiler/graphir_optimiser.cpp @@ -1384,7 +1384,7 @@ Operation *GraphIrOptimiser::RewriteReduceSum(Graph *const graph, Operation *con auto transposeAttr = transposeOp->Attribute(); transposeAttr->perm = {0, 2, 1}; // HCW transposeOp->CopyInput(TensorUsage::IFM, *ifmConn); - transposeOp->Input(TensorUsage::IFM)->Set(ifmShape3D); + transposeOp->Input(TensorUsage::IFM)->Set(ifmShape3D).Set(Quantization::Unit()); transposeOp->ConnectOutput(TensorUsage::OFM, transposeTens); RecordOptimisation(operation, transposeOp.get()); @@ -1392,9 +1392,9 @@ Operation *GraphIrOptimiser::RewriteReduceSum(Graph *const graph, Operation *con auto reduceSumOp = std::make_shared(OpType::ReduceSum); auto reduceAttr = reduceSumOp->Attribute(); reduceAttr->axis = 2; // C - reduceSumOp->ConnectInput(TensorUsage::IFM, transposeTens); + reduceSumOp->ConnectInput(TensorUsage::IFM, transposeTens).Set(ifmConn->quantization).Set(ifmConn->rounding); reduceSumOp->CopyOutput(TensorUsage::OFM, *ofmConn); - reduceSumOp->Output(TensorUsage::OFM)->Set(transposeTens->StorageShape().WithDepth(1)); + reduceSumOp->Output(TensorUsage::OFM)->Set(transposeTens->StorageShape().WithDepth(1)).Set(ofmConn->rounding); RecordOptimisation(operation, reduceSumOp.get()); returnOp = reduceSumOp.get(); @@ -1412,6 +1412,7 @@ Operation *GraphIrOptimiser::RewriteReduceSum(Graph *const graph, Operation *con std::shared_ptr reduceSumTens = ofmConn->tensor->Clone(); reduceSumTens->SetName(ofmConn->tensor->Name() + "_reducesum"); reduceSumTens->ChangeType(DataType::Int32); + reduceSumTens->SetStorageShape(ofmConn->shape); // Sub op with zero point auto zpTens = CreateConstTensor("zero_point", DataType::Int32, int(ifmConn->shape.Depth() * zp)); -- GitLab