From e989a504955781ae07101464995b0d168d7411cb Mon Sep 17 00:00:00 2001
From: Jacob Bohlin
Date: Wed, 7 May 2025 13:47:09 +0100
Subject: [PATCH] MLBEDSW-10681 Implement zero-point correction required for
 LSTM int16

Change-Id: I7bcef5bc787a7e00d7dee820a79bc451e1a97494
Signed-off-by: Jacob Bohlin
---
 .../regor/compiler/tflite_graph_optimiser.cpp | 122 ++++++++++++++++++
 .../regor/compiler/tflite_graph_optimiser.hpp |   4 +
 .../tflite/tflite_supported_operators.cpp     |   3 +-
 3 files changed, 128 insertions(+), 1 deletion(-)

diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp
index e3ecd8c7..5e722813 100644
--- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp
+++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp
@@ -2643,6 +2643,128 @@ Operation *TFLiteGraphOptimiser::ConvertZeroPoint(Graph *const graph, Operation
     return operation;
 }
 
+// The reference has some special cases that allow asymmetric int16 quantization, e.g. LSTM.
+// When lowering these ops the compiler can create other operators that inherit said
+// quantization, which may require legalization depending on the targeted hardware.
+Operation *TFLiteGraphOptimiser::LegalizeAsymmetricQuantization(Graph *const graph, Operation *const operation)
+{
+    auto returnOp = operation;
+    OpType opType = operation->Type();
+    TensorConnection *ifmConn = operation->Input(TensorUsage::IFM);
+    if ( ifmConn->quantization.zeroPoints.size() == 0 )
+    {
+        return returnOp;
+    }
+
+    auto ifmZeroPoint = ifmConn->quantization.zeroPoints[0];
+    DataType ifmDType = ifmConn->tensor->Type();
+    if ( !_constraints->SupportedZeroPoint(ifmZeroPoint, TensorUsage::IFM, ifmDType, opType) )
+    {
+        assert(ifmConn->quantization.zeroPoints.size() == 1);
+        TensorConnection *ofmConn = operation->Output(TensorUsage::OFM);
+        if ( opType == OpType::MemoryCopy || opType == OpType::Slice )
+        {
+            // The IFM and OFM are expected to have the same quantization, which means no
+            // data is modified and therefore the zero-point can simply be removed.
+            assert(ifmConn->quantization == ofmConn->quantization);
+            ifmConn->quantization.zeroPoints.clear();
+            ofmConn->quantization.zeroPoints.clear();
+        }
+        else
+        {
+            assert(opType == OpType::FullyConnected && "Unexpected OpType");
+            // Subtract the weight-adjusted ifm zero-point after the FullyConnected operation.
+            // Rationale (note, the '*' are vector products):
+            //   (ifm - zp) * w == ifm * w - zp * w
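+            // Scalar example for illustration: with zp = 2, ifm = 10 and w = 3, the
+            // reference computes (10 - 2) * 3 = 24; here the FullyConnected computes
+            // 10 * 3 = 30 and the Sub created below removes zp * w = 6, giving 24.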
+            TensorConnection *weightConn = operation->Input(TensorUsage::Weights);
+            auto weights = weightConn->tensor->View().Values(weightConn->tensor->Type());
+
+            // Calculate new zero-points by taking the vector product between the
+            // broadcast zero-point and the weights
+            std::vector<int32_t> weightAdjustedZeroPoints;
+            for ( int ic = 0; ic < weightConn->shape[0]; ic++ )
+            {
+                int value = 0;
+                for ( int oc = 0; oc < weightConn->shape[-1]; oc++ )
+                {
+                    value += ifmZeroPoint * weights[{ic, oc}];
+                }
+                weightAdjustedZeroPoints.emplace_back(value);
+            }
+            // Replicate the weight-adjusted zero-points for every batch
+            auto ofmShape = ofmConn->shape;
+            std::vector<int32_t> newZeroPoints;
+            newZeroPoints.reserve(ofmShape.Elements());
+            for ( int n = 0; n < ofmShape[0]; n++ )
+            {
+                newZeroPoints.insert(newZeroPoints.end(), weightAdjustedZeroPoints.begin(), weightAdjustedZeroPoints.end());
+            }
+
+            // Create zero-point tensor and higher precision intermediate tensor
+            auto zeroPointTens = CreateConstTensor(
+                "zeroPoints", DataType::Int32, std::make_shared<Buffer>(std::move(newZeroPoints)), &ofmShape);
+            auto intermediateTensor = std::make_shared<Tensor>(
+                fmt::format("{0}_zp_corrected", ofmConn->tensor->Name()), DataType::Int32, ofmShape);
+
+            // Compute the OFM quantization for the zero-point subtraction
+            float ifmScale = float(ifmConn->quantization.scales[0].Dequantize());
+            float ofmScale = float(ofmConn->quantization.scales[0].Dequantize());
+            float weightScale = float(weightConn->quantization.scales[0].Dequantize());
+            Quantization subOfmQuant;
+            subOfmQuant.scales = {QuantizedScale(double(ifmScale * weightScale) / double(ofmScale), true)};
+
+            // Create zero-point subtract op and set quantization parameters
+            const Quantization &unitQuant = Quantization::Unit();
+            auto zpCorrectOp = std::make_shared<Operation>(OpType::Sub);
+            zpCorrectOp->ConnectInput(TensorUsage::IFM, intermediateTensor).Set(ofmShape).Set(unitQuant);
+            zpCorrectOp->ConnectInput(TensorUsage::IFM1, zeroPointTens).Set(ofmShape).Set(unitQuant);
+            zpCorrectOp->ConnectOutput(TensorUsage::OFM, ofmConn->tensor).Set(ofmShape).Set(subOfmQuant);
+
+            operation->ConnectOutput(TensorUsage::OFM, intermediateTensor).Set(ofmShape).Set(unitQuant);
+            ifmConn->quantization = unitQuant;
+            weightConn->quantization = unitQuant;
+
+            RecordOptimisation(operation, zpCorrectOp.get());
+            returnOp = zpCorrectOp.get();
+        }
+    }
+
+    TensorConnection *ofmConn = operation->Output(TensorUsage::OFM);
+    if ( ofmConn->quantization.zeroPoints.size() == 0 )
+    {
+        return returnOp;
+    }
+
+    auto ofmZeroPoint = ofmConn->quantization.zeroPoints[0];
+    DataType ofmDType = ofmConn->tensor->Type();
+    if ( !_constraints->SupportedZeroPoint(ofmZeroPoint, TensorUsage::OFM, ofmDType, opType) )
+    {
+        assert(opType == OpType::Mul && "Unexpected OpType");
+
+        Quantization unitQuant = Quantization::Unit();
+        unitQuant.type = QuantizationType::TFLITE;
+        auto ofmQuantNoZP = ofmConn->quantization;
+        ofmQuantNoZP.zeroPoints = {0};
+
+        // Create zero-point tensor and higher precision intermediate tensor
+        auto zeroPointTens = CreateConstTensor("zeroPoint", ofmZeroPoint);
+        auto intermediateTensor = std::make_shared<Tensor>(
+            fmt::format("{0}_zp_corrected", ifmConn->tensor->Name()), DataType::Int32, ifmConn->SliceShape());
+
+        // Create zero-point subtract op and set quantization parameters
+        auto zpCorrectOp = std::make_shared<Operation>(OpType::Sub);
+        zpCorrectOp->ConnectInput(TensorUsage::IFM, intermediateTensor).Set(ifmConn->shape).Set(ofmQuantNoZP);
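+        // IFM1 carries the constant zero-point with unit quantization (set just below);
+        // the Sub applies the correction on the int32 intermediate so the OFM can be
+        // produced without the unsupported zero-point.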
+        zpCorrectOp->ConnectInput(TensorUsage::IFM1, zeroPointTens).Set(unitQuant);
+        zpCorrectOp->ConnectOutput(TensorUsage::OFM, ofmConn->tensor).Set(ofmConn->shape).Set(ofmQuantNoZP);
+
+        operation->ConnectOutput(TensorUsage::OFM, intermediateTensor).Set(ofmConn->shape).Set(ofmQuantNoZP);
+
+        RecordOptimisation(operation, zpCorrectOp.get());
+        returnOp = zpCorrectOp.get();
+    }
+
+    return returnOp;
+}
+
 // Return a slice of a tensor
 template<typename TYPE>
 static std::shared_ptr<Tensor>
diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.hpp b/ethosu/regor/compiler/tflite_graph_optimiser.hpp
index 30710470..088fe3ef 100644
--- a/ethosu/regor/compiler/tflite_graph_optimiser.hpp
+++ b/ethosu/regor/compiler/tflite_graph_optimiser.hpp
@@ -181,6 +181,9 @@ private:
     // Rewrites zero point as expected by reference
     Operation *ConvertZeroPoint(Graph *const graph, Operation *const operation);
 
+    // Legalizes asymmetric quantization, i.e. a non-zero zero-point, if required by the hardware
+    Operation *LegalizeAsymmetricQuantization(Graph *const graph, Operation *const operation);
+
 public:
     // The graph optimisation steps.
     // Order matters, array of rewrites processed in order.
@@ -276,6 +279,7 @@ public:
         {},
         {
             &TFLiteGraphOptimiser::ConvertZeroPoint,
+            &TFLiteGraphOptimiser::LegalizeAsymmetricQuantization,
         }
     },
     {
diff --git a/ethosu/regor/tflite/tflite_supported_operators.cpp b/ethosu/regor/tflite/tflite_supported_operators.cpp
index f30933c1..1618a396 100644
--- a/ethosu/regor/tflite/tflite_supported_operators.cpp
+++ b/ethosu/regor/tflite/tflite_supported_operators.cpp
@@ -276,7 +276,8 @@ bool TfLiteSupportedOperators::ConstraintZeroPoints(const Operation *op)
 {
     OpType opType = op->Type();
     // zeroPoints are ignored for the following operations to align with reference
-    if ( opType == OpType::AvgPool || opType == OpType::Resize || opType == OpType::CLZ || opType == OpType::SHL || opType == OpType::Div )
+    if ( opType == OpType::AvgPool || opType == OpType::Resize || opType == OpType::CLZ || opType == OpType::SHL ||
+         opType == OpType::Div || opType == OpType::UnidirectionalSequenceLstm )
     {
         return true;
     }
--
GitLab
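As an illustration of the identity the FullyConnected legalization above relies on, the short standalone program below (not part of the patch; all names and values are invented for the example) checks that computing the raw product ifm * w and then subtracting the weight-adjusted zero-point zp * w matches the reference computation (ifm - zp) * w:

#include <cassert>
#include <cstdint>
#include <vector>

int main()
{
    // Toy FullyConnected: 2 input channels, 3 output channels, one batch.
    const int64_t zeroPoint = -7;                             // asymmetric ifm zero-point
    const std::vector<int64_t> ifm = {100, -50};              // one value per input channel
    const int64_t weights[2][3] = {{3, -1, 4}, {2, 5, -6}};   // weights[ic][oc]

    for ( int oc = 0; oc < 3; oc++ )
    {
        int64_t adjusted = 0;    // (ifm - zp) * w, what the reference computes
        int64_t raw = 0;         // ifm * w, what the hardware computes without zp support
        int64_t correction = 0;  // zp * w, the constant subtracted afterwards
        for ( int ic = 0; ic < 2; ic++ )
        {
            adjusted += (ifm[ic] - zeroPoint) * weights[ic][oc];
            raw += ifm[ic] * weights[ic][oc];
            correction += zeroPoint * weights[ic][oc];
        }
        // (ifm - zp) * w == ifm * w - zp * w
        assert(adjusted == raw - correction);
    }
    return 0;
}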