From b15d154e7c0430175045f878035f66d80df547e5 Mon Sep 17 00:00:00 2001 From: Alexander Bengtsson Date: Thu, 24 Apr 2025 10:25:22 +0200 Subject: [PATCH 1/2] MLBEDSW-10735: Add supported-operator checks for ZeroPoints Change-Id: Ic34a96801c1aaecf8b174e6b11b7b26226e189bc Signed-off-by: Alexander Bengtsson --- .../architecture/architecture_constraints.hpp | 1 + .../ethosu55/ethos_u55_constraints.cpp | 46 +++++++------- .../ethosu55/ethos_u55_constraints.hpp | 1 + .../ethosu85/ethos_u85_constraints.cpp | 62 +++++++++---------- .../ethosu85/ethos_u85_constraints.hpp | 1 + .../tflite/tflite_supported_operators.cpp | 27 ++++++++ .../tflite/tflite_supported_operators.hpp | 1 + 7 files changed, 84 insertions(+), 55 deletions(-) diff --git a/ethosu/regor/architecture/architecture_constraints.hpp b/ethosu/regor/architecture/architecture_constraints.hpp index 2eed6e26..7003e0a9 100644 --- a/ethosu/regor/architecture/architecture_constraints.hpp +++ b/ethosu/regor/architecture/architecture_constraints.hpp @@ -132,6 +132,7 @@ public: virtual bool SupportsElementwiseLeakyRelu(bool quantized, DataType type) = 0; virtual bool SupportsRescale(DataType fromType, DataType toType) = 0; virtual Flags OperatorQuery(OpType opType, const ArchOperatorQuery *query = nullptr, ArchRequirements *req = nullptr) = 0; + virtual bool SupportedZeroPoint(int64_t zp, TensorUsage usage, DataType dType, OpType opType) = 0; }; } // namespace regor diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp b/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp index 8734442b..5bb3c4c4 100644 --- a/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp +++ b/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp @@ -278,34 +278,32 @@ bool EthosU55Constraints::SupportedDtypes(OpType opType, DataType ifmType, DataT return true; } -namespace +// Validate that zero-points are supported +bool EthosU55Constraints::SupportedZeroPoint(int64_t zp, TensorUsage usage, DataType dType, 
OpType opType) { -// Validate that IFM zero-points are supported based on ifmType and opType -bool SupportedIfmZeroPoint(int64_t zp, DataType ifmType, OpType opType) -{ - // must be zero for 32-bit IFM and for CLZ or SHL operations - if ( DataTypeSizeBits(ifmType) == 32 || opType == OpType::CLZ || opType == OpType::SHL ) - { - return zp == 0; - } - return true; -} -// Validate that OFM zero-points are supported based on opType -bool SupportedOfmZeroPoint(int64_t zp, DataType ofmType, OpType opType) -{ - // must be zero for CLZ or SHL operations - if ( opType == OpType::CLZ || opType == OpType::SHL ) + if ( IsIFM(usage) ) { - return zp == 0; + // must be zero for 32-bit IFM and for CLZ or SHL operations + if ( DataTypeSizeBits(dType) == 32 || opType == OpType::CLZ || opType == OpType::SHL ) + { + return zp == 0; + } } - // must be zero for 32-bit OFM unless op is an activation - if ( DataTypeSizeBits(ofmType) == 32 && !IsActivation(opType) ) + else if ( IsOFM(usage) ) { - return zp == 0; + // must be zero for CLZ or SHL operations + if ( opType == OpType::CLZ || opType == OpType::SHL ) + { + return zp == 0; + } + // must be zero for 32-bit OFM unless op is an activation + if ( DataTypeSizeBits(dType) == 32 && !IsActivation(opType) ) + { + return zp == 0; + } } return true; } -} // namespace Flags EthosU55Constraints::OperatorQuery(OpType opType, const ArchOperatorQuery *query, ArchRequirements *req) { @@ -374,21 +372,21 @@ Flags EthosU55Constraints::OperatorQuery(OpType opType, const ArchO { for ( auto zp : query->ifm[0].quantization.zeroPoints ) { - if ( !SupportedIfmZeroPoint(zp, ifmType, opType) ) + if ( !SupportedZeroPoint(zp, TensorUsage::IFM0, ifmType, opType) ) { return QueryResult::Unsupported; } } for ( auto zp : query->ifm[1].quantization.zeroPoints ) { - if ( !SupportedIfmZeroPoint(zp, ifm2Type, opType) ) + if ( !SupportedZeroPoint(zp, TensorUsage::IFM1, ifm2Type, opType) ) { return QueryResult::Unsupported; } } for ( auto zp : 
query->ofm.quantization.zeroPoints ) { - if ( !SupportedOfmZeroPoint(zp, ofmType, opType) ) + if ( !SupportedZeroPoint(zp, TensorUsage::OFM, ofmType, opType) ) { return QueryResult::Unsupported; } diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.hpp b/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.hpp index 3ae01c3b..13bb8740 100644 --- a/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.hpp +++ b/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.hpp @@ -37,6 +37,7 @@ public: bool SupportsElementwiseLeakyRelu(bool quantized, DataType type) override; bool SupportsRescale(DataType fromType, DataType toType) override; Flags OperatorQuery(OpType opType, const ArchOperatorQuery *query, ArchRequirements *req) override; + bool SupportedZeroPoint(int64_t zp, TensorUsage usage, DataType dtype, OpType opType) override; private: bool SupportedDtypes(OpType opType, DataType ifmType, DataType ifm2Type, DataType ofmType); diff --git a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp index 443b9295..b9b72156 100644 --- a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp +++ b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp @@ -262,40 +262,40 @@ bool EthosU85Constraints::SupportedDtypes(OpType opType, DataType ifmType, DataT } return true; } -namespace -{ -// Validate that IFM zero-points are supported based on ifmType -bool SupportedIfmZeroPoint(int64_t zp, DataType ifmType) -{ - switch ( ifmType ) - { - case DataType::Int8: - return (zp >= -128) && (zp <= 127); - break; - case DataType::UInt8: - return (zp >= 0) && (zp <= 255); - break; - case DataType::UInt16: - return (zp == 0) || (zp == 32768); - break; - default: - return zp == 0; - } - return false; -} -// Validate that OFM zero-points are supported based on ofmType -bool SupportedOfmZeroPoint(int64_t zp, DataType ofmType) + +// Validate that zero-points are supported +bool 
EthosU85Constraints::SupportedZeroPoint(int64_t zp, TensorUsage usage, DataType dType, OpType opType) { - if ( IsSignedInteger(ofmType) ) + if ( IsIFM(usage) ) { - return zp >= -128 && zp <= 127; + switch ( dType ) + { + case DataType::Int8: + return (zp >= -128) && (zp <= 127); + break; + case DataType::UInt8: + return (zp >= 0) && (zp <= 255); + break; + case DataType::UInt16: + return (zp == 0) || (zp == 32768); + break; + default: + return zp == 0; + } } - else + else if ( IsOFM(usage) ) { - return (zp == 32768) || (zp >= 0 && zp <= 255); + if ( IsSignedInteger(dType) ) + { + return zp >= -128 && zp <= 127; + } + else + { + return (zp == 32768) || (zp >= 0 && zp <= 255); + } } + return true; } -} // namespace Flags EthosU85Constraints::OperatorQuery(OpType opType, const ArchOperatorQuery *query, ArchRequirements *req) { @@ -393,21 +393,21 @@ Flags EthosU85Constraints::OperatorQuery(OpType opType, const ArchO { for ( auto zp : query->ifm[0].quantization.zeroPoints ) { - if ( !SupportedIfmZeroPoint(zp, ifmType) ) + if ( !SupportedZeroPoint(zp, TensorUsage::IFM0, ifmType, opType) ) { return QueryResult::Unsupported; } } for ( auto zp : query->ifm[1].quantization.zeroPoints ) { - if ( !SupportedIfmZeroPoint(zp, ifm2Type) ) + if ( !SupportedZeroPoint(zp, TensorUsage::IFM1, ifm2Type, opType) ) { return QueryResult::Unsupported; } } for ( auto zp : query->ofm.quantization.zeroPoints ) { - if ( !SupportedOfmZeroPoint(zp, ofmType) ) + if ( !SupportedZeroPoint(zp, TensorUsage::OFM, ofmType, opType) ) { return QueryResult::Unsupported; } diff --git a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.hpp b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.hpp index 66aa057b..bc66c2a0 100644 --- a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.hpp +++ b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.hpp @@ -36,6 +36,7 @@ public: bool SupportsElementwiseLeakyRelu(bool quantized, DataType type) override { return true; }; bool 
SupportsRescale(DataType fromType, DataType toType) override; Flags OperatorQuery(OpType opType, const ArchOperatorQuery *query, ArchRequirements *req) override; + bool SupportedZeroPoint(int64_t zp, TensorUsage usage, DataType dtype, OpType opType) override; private: bool SupportedDtypes(OpType opType, DataType ifmType, DataType ifm2Type, DataType ofmType); diff --git a/ethosu/regor/tflite/tflite_supported_operators.cpp b/ethosu/regor/tflite/tflite_supported_operators.cpp index 079d9fda..7fb733d8 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators.cpp @@ -259,6 +259,32 @@ bool TfLiteSupportedOperators::ConstraintMatchingQuantization(const Operation *o return true; } +bool TfLiteSupportedOperators::ConstraintZeroPoints(const Operation *op) +{ + OpType opType = op->Type(); + // zeroPoints are ignored for the following operations to align with reference + if ( opType == OpType::AvgPool || opType == OpType::Resize || opType == OpType::CLZ || opType == OpType::SHL || opType == OpType::Div ) + { + return true; + } + for ( const auto *list : {&op->Inputs(), &op->Outputs()} ) + { + for ( const auto &[usage, conn] : list->pairs() ) + { + DataType dType = conn.tensor->Type(); + for ( auto zp : conn.quantization.zeroPoints ) + { + if ( !_archConstraints->SupportedZeroPoint(zp, usage, dType, opType) ) + { + Failure(op, fmt::format("tensor {} has unsupported zeroPoint: {}", conn.tensor->Name(), zp)); + return false; + } + } + } + } + return true; +} + bool TfLiteSupportedOperators::ConstraintWeightsPrecision(const Operation *op) { const char *constraint = "Weight tensors must be 8-bit precision"; @@ -923,6 +949,7 @@ TfLiteSupportedOperators::TfLiteSupportedOperators(IArchitectureConstraints *con &TfLiteSupportedOperators::ConstraintTensQuantized, &TfLiteSupportedOperators::ConstraintPerAxisQuant, &TfLiteSupportedOperators::ConstraintMatchingQuantization, + &TfLiteSupportedOperators::ConstraintZeroPoints, 
&TfLiteSupportedOperators::ConstraintWeightsPrecision, &TfLiteSupportedOperators::ConstraintWeightSum, &TfLiteSupportedOperators::ConstraintBias, diff --git a/ethosu/regor/tflite/tflite_supported_operators.hpp b/ethosu/regor/tflite/tflite_supported_operators.hpp index 4b668abf..9ad96127 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.hpp +++ b/ethosu/regor/tflite/tflite_supported_operators.hpp @@ -59,6 +59,7 @@ private: bool ConstraintFCWeightShape(const Operation *op); bool ConstraintPerAxisQuant(const Operation *op); bool ConstraintMatchingQuantization(const Operation *op); + bool ConstraintZeroPoints(const Operation *op); bool ConstraintDepthMultiplier(const Operation *op); bool ConstraintWeightsPrecision(const Operation *op); bool ConstraintWeightSum(const Operation *op); -- GitLab From 3aae6aabfe9e9fca834784717f7d49bd074b7b5a Mon Sep 17 00:00:00 2001 From: Alexander Bengtsson Date: Thu, 24 Apr 2025 12:42:15 +0200 Subject: [PATCH 2/2] MLBEDSW-10736: Reject FullyConnected with too large tensors Change-Id: Ibde56bf108f8ef41b564fa9771a3cf4dd5b56896 Signed-off-by: Alexander Bengtsson --- .../regor/tflite/tflite_supported_operators.cpp | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/ethosu/regor/tflite/tflite_supported_operators.cpp b/ethosu/regor/tflite/tflite_supported_operators.cpp index 7fb733d8..e09c5020 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators.cpp @@ -191,7 +191,7 @@ bool TfLiteSupportedOperators::ConstraintTensQuantized(const Operation *op) bool TfLiteSupportedOperators::ConstraintFCWeightShape(const Operation *op) { - const char *constraint = "FullyConnected weights must be on the form I,1,1,..,1,O"; + const char *constraint = "FullyConnected weights must be on the form O,1,1,..,1,I"; if ( op->Type() != OpType::FullyConnected ) { return true; @@ -206,6 +206,19 @@ bool TfLiteSupportedOperators::ConstraintFCWeightShape(const Operation *op) 
Failure(op, fmt::format("Unsupported weights shape: {}", shape.ToString()), constraint); return false; } + + // IC and OC must not exceed 2^16 (NOTE(review): the messages below say "less than 2^16", but a value of exactly 2^16 passes these checks - confirm the intended bound) + // TODO MLBEDSW-10739: Decompose FullyConnected + if ( shape[0] > (1 << 16) ) + { + Failure(op, fmt::format("Output channels: {}", shape[0]), "Output channels must be less than 2^16"); + return false; + } + if ( shape[-1] > (1 << 16) ) + { + Failure(op, fmt::format("Input channels: {}", shape[-1]), "Input channels must be less than 2^16"); + return false; + } return true; } -- GitLab