diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp index 1ac16c86691de4af9936ce35925019a92fd88289..565535eee6b48d263e3b7d5804528a09937f4595 100644 --- a/ethosu/regor/compiler/graphir_optimiser.cpp +++ b/ethosu/regor/compiler/graphir_optimiser.cpp @@ -1384,7 +1384,7 @@ Operation *GraphIrOptimiser::RewriteReduceSum(Graph *const graph, Operation *con auto transposeAttr = transposeOp->Attribute(); transposeAttr->perm = {0, 2, 1}; // HCW transposeOp->CopyInput(TensorUsage::IFM, *ifmConn); - transposeOp->Input(TensorUsage::IFM)->Set(ifmShape3D); + transposeOp->Input(TensorUsage::IFM)->Set(ifmShape3D).Set(Quantization::Unit()); transposeOp->ConnectOutput(TensorUsage::OFM, transposeTens); RecordOptimisation(operation, transposeOp.get()); @@ -1392,9 +1392,9 @@ Operation *GraphIrOptimiser::RewriteReduceSum(Graph *const graph, Operation *con auto reduceSumOp = std::make_shared(OpType::ReduceSum); auto reduceAttr = reduceSumOp->Attribute(); reduceAttr->axis = 2; // C - reduceSumOp->ConnectInput(TensorUsage::IFM, transposeTens); + reduceSumOp->ConnectInput(TensorUsage::IFM, transposeTens).Set(ifmConn->quantization).Set(ifmConn->rounding); reduceSumOp->CopyOutput(TensorUsage::OFM, *ofmConn); - reduceSumOp->Output(TensorUsage::OFM)->Set(transposeTens->StorageShape().WithDepth(1)); + reduceSumOp->Output(TensorUsage::OFM)->Set(transposeTens->StorageShape().WithDepth(1)).Set(ofmConn->rounding); RecordOptimisation(operation, reduceSumOp.get()); returnOp = reduceSumOp.get(); @@ -1412,6 +1412,7 @@ Operation *GraphIrOptimiser::RewriteReduceSum(Graph *const graph, Operation *con std::shared_ptr reduceSumTens = ofmConn->tensor->Clone(); reduceSumTens->SetName(ofmConn->tensor->Name() + "_reducesum"); reduceSumTens->ChangeType(DataType::Int32); + reduceSumTens->SetStorageShape(ofmConn->shape); // Sub op with zero point auto zpTens = CreateConstTensor("zero_point", DataType::Int32, int(ifmConn->shape.Depth() * zp));