From 1cbc7bf41a5feb025e539192bb05a5002edb1443 Mon Sep 17 00:00:00 2001
From: Max Bergfelt
Date: Tue, 15 Apr 2025 09:18:28 +0200
Subject: [PATCH] MLBEDSW-10708: Fix setting Rescale IFM/OFM tensor dtype to unsigned

Made sure the dtype of a tensor is changed to unsigned in GraphIR
optimizer when the rescale operator has the attribute
input_unsigned/output_unsigned set to true.

Change-Id: I568638ef0e93252b55a2e748c2ed64d35a208977
Signed-off-by: Max Bergfelt
---
 ethosu/regor/compiler/graphir_optimiser.cpp | 52 +++++++++++++++++++++
 ethosu/regor/compiler/graphir_optimiser.hpp |  4 +-
 2 files changed, 55 insertions(+), 1 deletion(-)

diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp
index d0296632..d0c3b5a5 100644
--- a/ethosu/regor/compiler/graphir_optimiser.cpp
+++ b/ethosu/regor/compiler/graphir_optimiser.cpp
@@ -519,6 +519,58 @@ Operation *GraphIrOptimiser::RewriteRescaleInputs(Graph *const, Operation *const
     return returnOp;
 }
 
+Operation *GraphIrOptimiser::RemoveRescaleUnsignedAttribute(Graph *const, Operation *const operation)
+{
+    OpType opType = operation->Type();
+    if ( opType == OpType::Rescale )
+    {
+        auto signAttr = operation->Attribute<sign_attr_t>();
+        if ( signAttr->input_unsigned )
+        {
+            const auto &ifmConn = operation->Input(TensorUsage::IFM0);
+            DataType ifmType = ifmConn->tensor->Type();
+            auto newIfmType = ifmType & ~unsigned(DataType::Signed);
+
+            // Create a reinterpret OP to reinterpret the input as unsigned
+            auto reinterpretOp = std::make_shared<Operation>(OpType::ReinterpretCast);
+
+            // Create an unsigned data type tensor for the reinterpret OP
+            std::shared_ptr<Tensor> unsignedTensor = ifmConn->tensor->Clone();
+            unsignedTensor->ChangeType(newIfmType);
+
+            // Connect the reinterpret OP between the rescale and the rescale IFM
+            reinterpretOp->CopyInput(TensorUsage::IFM, *ifmConn);
+            reinterpretOp->ConnectOutput(TensorUsage::OFM, unsignedTensor);
+
+            // Connect the rescale OP input to the unsigned data type tensor
+            operation->ConnectInput(TensorUsage::IFM, unsignedTensor);
+            signAttr->input_unsigned = false;
+        }
+        if ( signAttr->output_unsigned )
+        {
+            const auto &ofmConn = operation->Output(TensorUsage::OFM);
+            DataType ofmType = ofmConn->tensor->Type();
+            auto newOfmType = ofmType & ~unsigned(DataType::Signed);
+
+            // Create a reinterpret OP to reinterpret the input as unsigned
+            auto reinterpretOp = std::make_shared<Operation>(OpType::ReinterpretCast);
+
+            // Create an unsigned data type tensor for the reinterpret OP
+            std::shared_ptr<Tensor> unsignedTensor = ofmConn->tensor->Clone();
+            unsignedTensor->ChangeType(newOfmType);
+
+            // Connect the reinterpret OP between the rescale and the rescale OFM
+            reinterpretOp->ConnectInput(TensorUsage::IFM, unsignedTensor);
+            reinterpretOp->CopyOutput(TensorUsage::OFM, *ofmConn);
+
+            // Connect the rescale OP output to the unsigned data type tensor
+            operation->ConnectOutput(TensorUsage::OFM, unsignedTensor);
+            signAttr->output_unsigned = false;
+        }
+    }
+    return operation;
+}
+
 /*
  * Lower Rescale into one (or more) 32-bit elementwise MUL operations.
  * Multipliers are moved to a constant-tensor, while the shift value is keps as ofm-quantization
diff --git a/ethosu/regor/compiler/graphir_optimiser.hpp b/ethosu/regor/compiler/graphir_optimiser.hpp
index c9f4c28a..31ac099d 100644
--- a/ethosu/regor/compiler/graphir_optimiser.hpp
+++ b/ethosu/regor/compiler/graphir_optimiser.hpp
@@ -50,6 +50,7 @@ private:
     Operation *RewriteFullyConnected(Graph *const graph, Operation *const operation);
     Operation *FixupPoolStrides(Graph *const, Operation *const operation);
     Operation *RewriteRescaleInputs(Graph *const graph, Operation *const operation);
+    Operation *RemoveRescaleUnsignedAttribute(Graph *const graph, Operation *const operation);
     Operation *RewriteRescale(Graph *const graph, Operation *const operation);
     Operation *RewritePad(Graph *const graph, Operation *const operation);
     Operation *FuseRescale(Graph *const graph, Operation *const operation);
@@ -124,13 +125,14 @@ private:
             {
                 &GraphIrOptimiser::ConvertAttributes,
                 &GraphIrOptimiser::RewriteRescaleInputs,
+                &GraphIrOptimiser::RemoveRescaleUnsignedAttribute,
                 &GraphIrOptimiser::FuseRescale, // First pass fuse all possible ifm and ofm rescales
             }
         },
         {
             {},
             {
-                &GraphIrOptimiser::FuseRescale, // Second pass, fuse any remaining ofm rescales after ifm fusing in first pass
+                &GraphIrOptimiser::FuseRescale, // Second pass, fuse any remaining ofm rescales after ifm fusing in first pass
             }
         },
         {
-- 
GitLab