diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp index d02966321efa27615dde7d29320289b213e481ac..d0c3b5a593362e33f65dda6f4bc8230c749a0f51 100644 --- a/ethosu/regor/compiler/graphir_optimiser.cpp +++ b/ethosu/regor/compiler/graphir_optimiser.cpp @@ -519,6 +519,58 @@ Operation *GraphIrOptimiser::RewriteRescaleInputs(Graph *const, Operation *const return returnOp; } +Operation *GraphIrOptimiser::RemoveRescaleUnsignedAttribute(Graph *const, Operation *const operation) +{ + OpType opType = operation->Type(); + if ( opType == OpType::Rescale ) + { + auto signAttr = operation->Attribute(); + if ( signAttr->input_unsigned ) + { + const auto &ifmConn = operation->Input(TensorUsage::IFM0); + DataType ifmType = ifmConn->tensor->Type(); + auto newIfmType = ifmType & ~unsigned(DataType::Signed); + + // Create a reinterpret OP to reinterpret the input as unsigned + auto reinterpretOp = std::make_shared(OpType::ReinterpretCast); + + // Create an unsigned data type tensor for the reinterpret OP + std::shared_ptr unsignedTensor = ifmConn->tensor->Clone(); + unsignedTensor->ChangeType(newIfmType); + + // Connect the reinterpret OP between the rescale and the rescale IFM + reinterpretOp->CopyInput(TensorUsage::IFM, *ifmConn); + reinterpretOp->ConnectOutput(TensorUsage::OFM, unsignedTensor); + + // Connect the rescale OP input to the unsigned data type tensor + operation->ConnectInput(TensorUsage::IFM, unsignedTensor); + signAttr->input_unsigned = false; + } + if ( signAttr->output_unsigned ) + { + const auto &ofmConn = operation->Output(TensorUsage::OFM); + DataType ofmType = ofmConn->tensor->Type(); + auto newOfmType = ofmType & ~unsigned(DataType::Signed); + + // Create a reinterpret OP to reinterpret the input as unsigned + auto reinterpretOp = std::make_shared(OpType::ReinterpretCast); + + // Create an unsigned data type tensor for the reinterpret OP + std::shared_ptr unsignedTensor = ofmConn->tensor->Clone(); + unsignedTensor->ChangeType(newOfmType); + + // Connect the reinterpret OP between the rescale and the rescale OFM + reinterpretOp->ConnectInput(TensorUsage::IFM, unsignedTensor); + reinterpretOp->CopyOutput(TensorUsage::OFM, *ofmConn); + + // Connect the rescale OP output to the unsigned data type tensor + operation->ConnectOutput(TensorUsage::OFM, unsignedTensor); + signAttr->output_unsigned = false; + } + } + return operation; +} + /* * Lower Rescale into one (or more) 32-bit elementwise MUL operations. * Multipliers are moved to a constant-tensor, while the shift value is keps as ofm-quantization diff --git a/ethosu/regor/compiler/graphir_optimiser.hpp b/ethosu/regor/compiler/graphir_optimiser.hpp index c9f4c28a22611edaa701cf63afd986942e2a3769..31ac099d7b72139addf0898960e2fa734e557072 100644 --- a/ethosu/regor/compiler/graphir_optimiser.hpp +++ b/ethosu/regor/compiler/graphir_optimiser.hpp @@ -50,6 +50,7 @@ private: Operation *RewriteFullyConnected(Graph *const graph, Operation *const operation); Operation *FixupPoolStrides(Graph *const, Operation *const operation); Operation *RewriteRescaleInputs(Graph *const graph, Operation *const operation); + Operation *RemoveRescaleUnsignedAttribute(Graph *const graph, Operation *const operation); Operation *RewriteRescale(Graph *const graph, Operation *const operation); Operation *RewritePad(Graph *const graph, Operation *const operation); Operation *FuseRescale(Graph *const graph, Operation *const operation); @@ -124,13 +125,14 @@ private: { &GraphIrOptimiser::ConvertAttributes, &GraphIrOptimiser::RewriteRescaleInputs, + &GraphIrOptimiser::RemoveRescaleUnsignedAttribute, &GraphIrOptimiser::FuseRescale, // First pass fuse all possible ifm and ofm rescales } }, { {}, { - &GraphIrOptimiser::FuseRescale, // Second pass, fuse any remaining ofm rescales after ifm fusing in first pass + &GraphIrOptimiser::FuseRescale, // Second pass, fuse any remaining ofm rescales after ifm fusing in first pass } }, {