Fuse Rescale via tensor reader/writer replacement; add Reshape regression tests

Summary of the change (everything below is taken from the hunks in this patch):

  * ethosu/regor/compiler/graphir_optimiser.cpp, GraphIrOptimiser::FuseRescale
      - Forward fuse: `consumer->CopyInput(ifm.first, *ifmConn)` is replaced by
        `ReplaceConsumerInput(nullptr, ofmConn->tensor->Readers(),
        ofmConn->tensor.get(), ifmConn->tensor)`.
      - Backward fuse: `producer->CopyOutput(TensorUsage::OFM, *ofmConn)` is
        replaced by `ReplaceProducerOutput(ifmConn->tensor->Writers(),
        ifmConn->tensor.get(), ofmConn->tensor)`.
      - The two helpers are not defined in this patch. Judging by their names
        and arguments they presumably rewire every reader (resp. writer) of the
        tensor instead of copying a single connection; confirm against their
        definitions before relying on that.

  * ethosu/regor/test/test_graphir_optimiser.cpp
      - "fuse rescale with reshape, before": builds RESCALE -> RESHAPE -> ABS
        (Int8, 1x8x2x1 reshaped to 1x4x4x1, multiplier 1073741824, shift 31) and
        requires that after `optimiser->Process` only the Abs operation remains,
        that its IFM has slice shape 1x4x4x1 with zero point 0, scale 1073741824
        and shift 31, and that its OFM has slice shape 1x4x4x1.
      - "fuse rescale with reshape, after": builds ABS -> RESHAPE -> RESCALE
        with the same parameters and requires that only the Abs operation
        remains, that its IFM and OFM have slice shape 1x4x4x1, and that the OFM
        carries zero point 0, scale 1073741824 and shift 31.

Reviewer notes on this copy of the patch:

  * The patch had been flattened onto a few very long lines. One diff line per
    physical line has been restored; the result matches the line counts in all
    three hunk headers (7/7, 7/7 and 3/97).
  * Leading indentation was lost in the flattening and has been reconstructed
    by hand. Context lines may therefore differ from the target files in
    whitespace only; apply with whitespace-insensitive matching
    (e.g. `git apply --ignore-whitespace`) or re-indent against the tree.
  * Template argument lists had been stripped (`std::vector> ops`,
    `std::vector allOps`, `Attribute()`). They are restored here as
    `std::vector<std::shared_ptr<Operation>>`, `std::vector<Operation *>`,
    `Attribute<rescale_attr_t>()` and `Attribute<sign_attr_t>()`. These are
    presumed from the fields and calls used in the tests; confirm against the
    upstream commit.
  * `CreateArchDefault()` is left exactly as found. If that helper is a
    template, its argument list may have been stripped as well; confirm.

diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp
index 1e83461a71aceddf066056be930b592959121794..efa5e3c43fd5eb09153f93c388f961d39eb92e4b 100644
--- a/ethosu/regor/compiler/graphir_optimiser.cpp
+++ b/ethosu/regor/compiler/graphir_optimiser.cpp
@@ -1264,7 +1264,7 @@ Operation *GraphIrOptimiser::FuseRescale(Graph *const graph, Operation *const op
                     // avoid performing this fuse operation.
                     if ( !sameType ) break;
                 }
-                consumer->CopyInput(ifm.first, *ifmConn);
+                ReplaceConsumerInput(nullptr, ofmConn->tensor->Readers(), ofmConn->tensor.get(), ifmConn->tensor);
                 ifm.second.quantization = ifmQuant;
                 consumer->Input(ifm.first)->Set(ofmConn->rounding);
                 returnOp = consumer.get();
@@ -1290,7 +1290,7 @@ Operation *GraphIrOptimiser::FuseRescale(Graph *const graph, Operation *const op
                 ofmConn->tensor->Type(), producer->IFM(0)->Type(), producer->OFM()->Type(), ofmQuant) )
         {
             // Propagate rescaling to output of previous op
-            producer->CopyOutput(TensorUsage::OFM, *ofmConn);
+            ReplaceProducerOutput(ifmConn->tensor->Writers(), ifmConn->tensor.get(), ofmConn->tensor);
             producer->Output(TensorUsage::OFM)->Set(ofmConn->rounding).Set(ofmQuant);
             returnOp = producer.get();
         }
diff --git a/ethosu/regor/test/test_graphir_optimiser.cpp b/ethosu/regor/test/test_graphir_optimiser.cpp
index f0c6434148081a582c8485703f66c31c41a911a0..3adbdfa72e3e394e061c690d7a869286d19d0e52 100644
--- a/ethosu/regor/test/test_graphir_optimiser.cpp
+++ b/ethosu/regor/test/test_graphir_optimiser.cpp
@@ -319,3 +319,97 @@ TEST_CASE("test_graphir_optimiser - replace pad by explicit padding")
     REQUIRE(padding.Near() == 0);
     REQUIRE(padding.Far() == 0);
 }
+
+TEST_CASE("test_graphir_optimiser - fuse rescale with reshape, before")
+{
+    // Create arch
+    auto arch = CreateArchDefault();
+    std::string err = "noerror";
+    arch->CheckConfiguration(err);
+    REQUIRE(err == "noerror");
+
+    std::vector<std::shared_ptr<Operation>> ops;
+    auto input = CreateTensor("INPUT", Shape(1, 8, 2, 1), DataType::Int8);
+    auto mulParam = CreateTensor("MUL_PARAM", Shape(1, 1), DataType::Int32, 1073741824);
+    auto shiftParam = CreateTensor("SHIFT_PARAM", Shape(1, 1), DataType::Int8, 31);
+    auto rescaleOfm = CreateTensor("RESCALE_OFM", Shape(1, 8, 2, 1), DataType::Int8);
+    auto reshapeOfm = CreateTensor("RESHAPE_OFM", Shape(1, 4, 4, 1), DataType::Int8);
+    auto absOfm = CreateTensor("ABS_OFM", Shape(1, 4, 4, 1), DataType::Int8);
+
+    // Create a RESCALE-RESHAPE-ABS graph
+    ops.push_back(CreateOperation(OpType::Rescale, TensorUsage::IFM, input, TensorUsage::Params0, mulParam,
+        TensorUsage::Params1, shiftParam, TensorUsage::OFM, rescaleOfm));
+    auto *rescaleAttr = ops.back()->Attribute<rescale_attr_t>();
+    rescaleAttr->double_round = false;
+    rescaleAttr->per_channel = false;
+    rescaleAttr->scale32 = true;
+    auto *signAttr = ops.back()->Attribute<sign_attr_t>();
+    signAttr->input_unsigned = false;
+    signAttr->output_unsigned = false;
+    ops.push_back(CreateOperation(OpType::Reshape, TensorUsage::IFM, rescaleOfm, TensorUsage::OFM, reshapeOfm));
+    ops.push_back(CreateOperation(OpType::Abs, TensorUsage::IFM, reshapeOfm, TensorUsage::OFM, absOfm));
+
+    auto graph = CreateGraph(ops);
+
+    GraphOptimiserOptions options;
+    auto optimiser = GraphOptimiser::MakeGraphOptimiser(graph->Notation(), arch.get(), options, nullptr);
+
+    optimiser->Process(graph.get());
+
+    std::vector<Operation *> allOps;
+    graph->GetAllOperations(allOps);
+    REQUIRE(allOps.size() == 1);
+    REQUIRE(allOps[0]->Type() == OpType::Abs);
+    REQUIRE(allOps[0]->Input(TensorUsage::IFM)->SliceShape() == Shape(1, 4, 4, 1));
+    REQUIRE(allOps[0]->Input(TensorUsage::IFM)->quantization.zeroPoints[0] == 0);
+    REQUIRE(allOps[0]->Input(TensorUsage::IFM)->quantization.scales[0].scale == 1073741824);
+    REQUIRE(allOps[0]->Input(TensorUsage::IFM)->quantization.scales[0].shift == 31);
+    REQUIRE(allOps[0]->Output(TensorUsage::OFM)->SliceShape() == Shape(1, 4, 4, 1));
+}
+
+TEST_CASE("test_graphir_optimiser - fuse rescale with reshape, after")
+{
+    // Create arch
+    auto arch = CreateArchDefault();
+    std::string err = "noerror";
+    arch->CheckConfiguration(err);
+    REQUIRE(err == "noerror");
+
+    std::vector<std::shared_ptr<Operation>> ops;
+    auto input = CreateTensor("INPUT", Shape(1, 4, 4, 1), DataType::Int8);
+    auto absOfm = CreateTensor("ABS_OFM", Shape(1, 4, 4, 1), DataType::Int8);
+    auto reshapeOfm = CreateTensor("RESHAPE_OFM", Shape(1, 8, 2, 1), DataType::Int8);
+    auto mulParam = CreateTensor("MUL_PARAM", Shape(1, 1), DataType::Int32, 1073741824);
+    auto shiftParam = CreateTensor("SHIFT_PARAM", Shape(1, 1), DataType::Int8, 31);
+    auto rescaleOfm = CreateTensor("RESCALE_OFM", Shape(1, 8, 2, 1), DataType::Int8);
+
+    // Create a ABS-RESHAPE-RESCALE graph
+    ops.push_back(CreateOperation(OpType::Abs, TensorUsage::IFM, input, TensorUsage::OFM, absOfm));
+    ops.push_back(CreateOperation(OpType::Reshape, TensorUsage::IFM, absOfm, TensorUsage::OFM, reshapeOfm));
+    ops.push_back(CreateOperation(OpType::Rescale, TensorUsage::IFM, reshapeOfm, TensorUsage::Params0, mulParam,
+        TensorUsage::Params1, shiftParam, TensorUsage::OFM, rescaleOfm));
+    auto *rescaleAttr = ops.back()->Attribute<rescale_attr_t>();
+    rescaleAttr->double_round = false;
+    rescaleAttr->per_channel = false;
+    rescaleAttr->scale32 = true;
+    auto *signAttr = ops.back()->Attribute<sign_attr_t>();
+    signAttr->input_unsigned = false;
+    signAttr->output_unsigned = false;
+
+    auto graph = CreateGraph(ops);
+
+    GraphOptimiserOptions options;
+    auto optimiser = GraphOptimiser::MakeGraphOptimiser(graph->Notation(), arch.get(), options, nullptr);
+
+    optimiser->Process(graph.get());
+
+    std::vector<Operation *> allOps;
+    graph->GetAllOperations(allOps);
+    REQUIRE(allOps.size() == 1);
+    REQUIRE(allOps[0]->Type() == OpType::Abs);
+    REQUIRE(allOps[0]->Input(TensorUsage::IFM)->SliceShape() == Shape(1, 4, 4, 1));
+    REQUIRE(allOps[0]->Output(TensorUsage::OFM)->SliceShape() == Shape(1, 4, 4, 1));
+    REQUIRE(allOps[0]->Output(TensorUsage::OFM)->quantization.zeroPoints[0] == 0);
+    REQUIRE(allOps[0]->Output(TensorUsage::OFM)->quantization.scales[0].scale == 1073741824);
+    REQUIRE(allOps[0]->Output(TensorUsage::OFM)->quantization.scales[0].shift == 31);
+}