diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp index 24bab1b997f34e88dafa774b3065611fb22ed34b..aafa001773b36527371801b187eb27a2844f58c7 100644 --- a/ethosu/regor/compiler/graphir_optimiser.cpp +++ b/ethosu/regor/compiler/graphir_optimiser.cpp @@ -833,31 +833,34 @@ Operation *GraphIrOptimiser::FuseRescale(Graph *const graph, Operation *const op return returnOp; } + // Convert scales to have 0 shift if possible, since this can improve fusing for Ethos-U55/65 + auto ConvertedScales = [](const TensorConnection *conn) + { + auto scales = conn->quantization.scales; + for ( auto &qs : scales ) + { + if ( qs.shift > 0 && qs.shift < 31 && (qs.scale % (1 << qs.shift)) == 0 ) + { + qs = {(qs.scale >> qs.shift), 0}; + } + } + return scales; + }; + // Check if there is only one consumer of the output of the rescale and try to fuse to that operation. // Note: For input fusing we cannot have an output zero point on the Rescale operation (since the // zero point is applied before scaling on inputs), however input zero point is fine. if ( ofmConn->tensor->Readers().size() == 1 && ofmConn->quantization.zeroPoints == Quantization::Unit().zeroPoints ) { - // Copies quantization information from ifm connection and (converted) scales from ofm connection, - // since these are the scales we want to apply. - auto CopyQuantizationAndConvertScales = [](const TensorConnection *ic, const TensorConnection *oc) - { - auto result = ic->quantization; - result.scales = oc->quantization.scales; - // Convert scales to have 0 shift if possible, since this can - // improve fusing for Ethos-U55/65 - for ( auto &qs : result.scales ) - { - if ( qs.shift > 0 && qs.shift < 31 && (qs.scale % (1 << qs.shift)) == 0 ) - { - qs = {(qs.scale >> qs.shift), 0}; - } - } - return result; - }; // Propagate rescaling to input of next op auto consumer = ofmConn->tensor->Readers().front(); - auto ifmQuant = CopyQuantizationAndConvertScales(ifmConn, ofmConn); + + // Copy quantization information from ifm connection and (converted) scales from ofm connection, + // since these are the scales we want to apply. + auto ifmQuant = ifmConn->quantization; + // Normalize scales to shift 0 if possible + ifmQuant.scales = ConvertedScales(ofmConn); + for ( auto ifm : consumer->Inputs().pairs() ) { if ( ifm.second.tensor == ofmConn->tensor ) @@ -883,8 +886,9 @@ Operation *GraphIrOptimiser::FuseRescale(Graph *const graph, Operation *const op if ( otherProducer && otherProducer->Type() == OpType::Rescale ) { // Check if the other ifm rescale can be fused - auto otherIfmQuant = CopyQuantizationAndConvertScales( - otherProducer->Input(TensorUsage::IFM), otherProducer->Output(TensorUsage::OFM)); + auto otherIfmQuant = otherProducer->Input(TensorUsage::IFM)->quantization; + otherIfmQuant.scales = ConvertedScales(otherProducer->Output(TensorUsage::OFM)); + if ( otherIfmCon->quantization.EqualScales(Quantization::Unit()) && _constraints->SupportsFusedRescale(consumer->Type(), TensorUsage::IFM, otherProducer->IFM(0)->Type(), otherProducer->OFM()->Type(), otherIfmQuant) ) @@ -910,10 +914,15 @@ Operation *GraphIrOptimiser::FuseRescale(Graph *const graph, Operation *const op // is only one producer of the input to the rescale operation. If this input has no zero point // adjustment and the producers output has unit scaling, it might be possible to fuse this rescale to // the producers output if the constraints of the architecture allows it + + // Normalize scales to shift 0 if possible + auto ofmQuant = ofmConn->quantization; + ofmQuant.scales = ConvertedScales(ofmConn); + if ( returnOp == operation && producer && producer->Output(TensorUsage::OFM)->quantization.EqualScales(Quantization::Unit()) && ifmConn->quantization.zeroPoints == Quantization::Unit().zeroPoints && - _constraints->SupportsFusedRescale(producer->Type(), TensorUsage::OFM, producer->IFM(0)->Type(), - ofmConn->tensor->Type(), ofmConn->quantization) ) + _constraints->SupportsFusedRescale( + producer->Type(), TensorUsage::OFM, producer->IFM(0)->Type(), ofmConn->tensor->Type(), ofmQuant) ) { // Propagate rescaling to output of previous op producer->CopyOutput(TensorUsage::OFM, *ofmConn);