From d37febc1715edf0d236c2ff555739a8a9aadcf9a Mon Sep 17 00:00:00 2001 From: Johan Gunnarsson Date: Tue, 17 Jun 2025 17:36:07 +0200 Subject: [PATCH] MLBEDSW-10905: Keep original shapes in FuseRescale CopyInput/CopyOutput replaces the TensorConnection's shape with the tensor's storage shape. This is undesired in case the original network had a RESHAPE between our op and the RESCALE. In that case we want to preserve the original shape. Signed-off-by: Johan Gunnarsson Change-Id: I1e78ce6ec03f71c06a35bb538d3388090178241c --- ethosu/regor/compiler/graphir_optimiser.cpp | 4 +- ethosu/regor/test/test_graphir_optimiser.cpp | 94 ++++++++++++++++++++ 2 files changed, 96 insertions(+), 2 deletions(-) diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp index 1e83461a..efa5e3c4 100644 --- a/ethosu/regor/compiler/graphir_optimiser.cpp +++ b/ethosu/regor/compiler/graphir_optimiser.cpp @@ -1264,7 +1264,7 @@ Operation *GraphIrOptimiser::FuseRescale(Graph *const graph, Operation *const op // avoid performing this fuse operation. if ( !sameType ) break; } - consumer->CopyInput(ifm.first, *ifmConn); + ReplaceConsumerInput(nullptr, ofmConn->tensor->Readers(), ofmConn->tensor.get(), ifmConn->tensor); ifm.second.quantization = ifmQuant; consumer->Input(ifm.first)->Set(ofmConn->rounding); returnOp = consumer.get(); @@ -1290,7 +1290,7 @@ Operation *GraphIrOptimiser::FuseRescale(Graph *const graph, Operation *const op ofmConn->tensor->Type(), producer->IFM(0)->Type(), producer->OFM()->Type(), ofmQuant) ) { // Propagate rescaling to output of previous op - producer->CopyOutput(TensorUsage::OFM, *ofmConn); + ReplaceProducerOutput(ifmConn->tensor->Writers(), ifmConn->tensor.get(), ofmConn->tensor); producer->Output(TensorUsage::OFM)->Set(ofmConn->rounding).Set(ofmQuant); returnOp = producer.get(); } diff --git a/ethosu/regor/test/test_graphir_optimiser.cpp b/ethosu/regor/test/test_graphir_optimiser.cpp index f0c64341..3adbdfa7 100644 --- a/ethosu/regor/test/test_graphir_optimiser.cpp +++ b/ethosu/regor/test/test_graphir_optimiser.cpp @@ -319,3 +319,97 @@ TEST_CASE("test_graphir_optimiser - replace pad by explicit padding") REQUIRE(padding.Near() == 0); REQUIRE(padding.Far() == 0); } + +TEST_CASE("test_graphir_optimiser - fuse rescale with reshape, before") +{ + // Create arch + auto arch = CreateArchDefault(); + std::string err = "noerror"; + arch->CheckConfiguration(err); + REQUIRE(err == "noerror"); + + std::vector> ops; + auto input = CreateTensor("INPUT", Shape(1, 8, 2, 1), DataType::Int8); + auto mulParam = CreateTensor("MUL_PARAM", Shape(1, 1), DataType::Int32, 1073741824); + auto shiftParam = CreateTensor("SHIFT_PARAM", Shape(1, 1), DataType::Int8, 31); + auto rescaleOfm = CreateTensor("RESCALE_OFM", Shape(1, 8, 2, 1), DataType::Int8); + auto reshapeOfm = CreateTensor("RESHAPE_OFM", Shape(1, 4, 4, 1), DataType::Int8); + auto absOfm = CreateTensor("ABS_OFM", Shape(1, 4, 4, 1), DataType::Int8); + + // Create a RESCALE-RESHAPE-ABS graph + ops.push_back(CreateOperation(OpType::Rescale, TensorUsage::IFM, input, TensorUsage::Params0, mulParam, + TensorUsage::Params1, shiftParam, TensorUsage::OFM, rescaleOfm)); + auto *rescaleAttr = ops.back()->Attribute(); + rescaleAttr->double_round = false; + rescaleAttr->per_channel = false; + rescaleAttr->scale32 = true; + auto *signAttr = ops.back()->Attribute(); + signAttr->input_unsigned = false; + signAttr->output_unsigned = false; + ops.push_back(CreateOperation(OpType::Reshape, TensorUsage::IFM, rescaleOfm, TensorUsage::OFM, reshapeOfm)); + ops.push_back(CreateOperation(OpType::Abs, TensorUsage::IFM, reshapeOfm, TensorUsage::OFM, absOfm)); + + auto graph = CreateGraph(ops); + + GraphOptimiserOptions options; + auto optimiser = GraphOptimiser::MakeGraphOptimiser(graph->Notation(), arch.get(), options, nullptr); + + optimiser->Process(graph.get()); + + std::vector allOps; + graph->GetAllOperations(allOps); + REQUIRE(allOps.size() == 1); + REQUIRE(allOps[0]->Type() == OpType::Abs); + REQUIRE(allOps[0]->Input(TensorUsage::IFM)->SliceShape() == Shape(1, 4, 4, 1)); + REQUIRE(allOps[0]->Input(TensorUsage::IFM)->quantization.zeroPoints[0] == 0); + REQUIRE(allOps[0]->Input(TensorUsage::IFM)->quantization.scales[0].scale == 1073741824); + REQUIRE(allOps[0]->Input(TensorUsage::IFM)->quantization.scales[0].shift == 31); + REQUIRE(allOps[0]->Output(TensorUsage::OFM)->SliceShape() == Shape(1, 4, 4, 1)); +} + +TEST_CASE("test_graphir_optimiser - fuse rescale with reshape, after") +{ + // Create arch + auto arch = CreateArchDefault(); + std::string err = "noerror"; + arch->CheckConfiguration(err); + REQUIRE(err == "noerror"); + + std::vector> ops; + auto input = CreateTensor("INPUT", Shape(1, 4, 4, 1), DataType::Int8); + auto absOfm = CreateTensor("ABS_OFM", Shape(1, 4, 4, 1), DataType::Int8); + auto reshapeOfm = CreateTensor("RESHAPE_OFM", Shape(1, 8, 2, 1), DataType::Int8); + auto mulParam = CreateTensor("MUL_PARAM", Shape(1, 1), DataType::Int32, 1073741824); + auto shiftParam = CreateTensor("SHIFT_PARAM", Shape(1, 1), DataType::Int8, 31); + auto rescaleOfm = CreateTensor("RESCALE_OFM", Shape(1, 8, 2, 1), DataType::Int8); + + // Create a ABS-RESHAPE-RESCALE graph + ops.push_back(CreateOperation(OpType::Abs, TensorUsage::IFM, input, TensorUsage::OFM, absOfm)); + ops.push_back(CreateOperation(OpType::Reshape, TensorUsage::IFM, absOfm, TensorUsage::OFM, reshapeOfm)); + ops.push_back(CreateOperation(OpType::Rescale, TensorUsage::IFM, reshapeOfm, TensorUsage::Params0, mulParam, + TensorUsage::Params1, shiftParam, TensorUsage::OFM, rescaleOfm)); + auto *rescaleAttr = ops.back()->Attribute(); + rescaleAttr->double_round = false; + rescaleAttr->per_channel = false; + rescaleAttr->scale32 = true; + auto *signAttr = ops.back()->Attribute(); + signAttr->input_unsigned = false; + signAttr->output_unsigned = false; + + auto graph = CreateGraph(ops); + + GraphOptimiserOptions options; + auto optimiser = GraphOptimiser::MakeGraphOptimiser(graph->Notation(), arch.get(), options, nullptr); + + optimiser->Process(graph.get()); + + std::vector allOps; + graph->GetAllOperations(allOps); + REQUIRE(allOps.size() == 1); + REQUIRE(allOps[0]->Type() == OpType::Abs); + REQUIRE(allOps[0]->Input(TensorUsage::IFM)->SliceShape() == Shape(1, 4, 4, 1)); + REQUIRE(allOps[0]->Output(TensorUsage::OFM)->SliceShape() == Shape(1, 4, 4, 1)); + REQUIRE(allOps[0]->Output(TensorUsage::OFM)->quantization.zeroPoints[0] == 0); + REQUIRE(allOps[0]->Output(TensorUsage::OFM)->quantization.scales[0].scale == 1073741824); + REQUIRE(allOps[0]->Output(TensorUsage::OFM)->quantization.scales[0].shift == 31); +} -- GitLab