diff --git a/ethosu/regor/architecture/architecture_constraints.hpp b/ethosu/regor/architecture/architecture_constraints.hpp index 2eed6e266c456a84fa8ebc4ca147d74101bbd478..7003e0a943dce172ed05509f65029711de301a0a 100644 --- a/ethosu/regor/architecture/architecture_constraints.hpp +++ b/ethosu/regor/architecture/architecture_constraints.hpp @@ -132,6 +132,7 @@ public: virtual bool SupportsElementwiseLeakyRelu(bool quantized, DataType type) = 0; virtual bool SupportsRescale(DataType fromType, DataType toType) = 0; virtual Flags OperatorQuery(OpType opType, const ArchOperatorQuery *query = nullptr, ArchRequirements *req = nullptr) = 0; + virtual bool SupportedZeroPoint(int64_t zp, TensorUsage usage, DataType dType, OpType opType) = 0; }; } // namespace regor diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp b/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp index 8734442bba27e26ac55dcd378a7f8651dfa47d24..5bb3c4c489daccf9cb71dc5e0543649ba33e0a7a 100644 --- a/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp +++ b/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp @@ -278,34 +278,32 @@ bool EthosU55Constraints::SupportedDtypes(OpType opType, DataType ifmType, DataT return true; } -namespace +// Validate that zero-points are supported +bool EthosU55Constraints::SupportedZeroPoint(int64_t zp, TensorUsage usage, DataType dType, OpType opType) { -// Validate that IFM zero-points are supported based on ifmType and opType -bool SupportedIfmZeroPoint(int64_t zp, DataType ifmType, OpType opType) -{ - // must be zero for 32-bit IFM and for CLZ or SHL operations - if ( DataTypeSizeBits(ifmType) == 32 || opType == OpType::CLZ || opType == OpType::SHL ) - { - return zp == 0; - } - return true; -} -// Validate that OFM zero-points are supported based on opType -bool SupportedOfmZeroPoint(int64_t zp, DataType ofmType, OpType opType) -{ - // must be zero for CLZ or SHL operations - if ( opType == OpType::CLZ || opType == OpType::SHL ) + if ( IsIFM(usage) ) { - return zp == 0; + // must be zero for 32-bit IFM and for CLZ or SHL operations + if ( DataTypeSizeBits(dType) == 32 || opType == OpType::CLZ || opType == OpType::SHL ) + { + return zp == 0; + } } - // must be zero for 32-bit OFM unless op is an activation - if ( DataTypeSizeBits(ofmType) == 32 && !IsActivation(opType) ) + else if ( IsOFM(usage) ) { - return zp == 0; + // must be zero for CLZ or SHL operations + if ( opType == OpType::CLZ || opType == OpType::SHL ) + { + return zp == 0; + } + // must be zero for 32-bit OFM unless op is an activation + if ( DataTypeSizeBits(dType) == 32 && !IsActivation(opType) ) + { + return zp == 0; + } } return true; } -} // namespace Flags EthosU55Constraints::OperatorQuery(OpType opType, const ArchOperatorQuery *query, ArchRequirements *req) { @@ -374,21 +372,21 @@ Flags EthosU55Constraints::OperatorQuery(OpType opType, const ArchO { for ( auto zp : query->ifm[0].quantization.zeroPoints ) { - if ( !SupportedIfmZeroPoint(zp, ifmType, opType) ) + if ( !SupportedZeroPoint(zp, TensorUsage::IFM0, ifmType, opType) ) { return QueryResult::Unsupported; } } for ( auto zp : query->ifm[1].quantization.zeroPoints ) { - if ( !SupportedIfmZeroPoint(zp, ifm2Type, opType) ) + if ( !SupportedZeroPoint(zp, TensorUsage::IFM1, ifm2Type, opType) ) { return QueryResult::Unsupported; } } for ( auto zp : query->ofm.quantization.zeroPoints ) { - if ( !SupportedOfmZeroPoint(zp, ofmType, opType) ) + if ( !SupportedZeroPoint(zp, TensorUsage::OFM, ofmType, opType) ) { return QueryResult::Unsupported; } diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.hpp b/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.hpp index 3ae01c3ba255d613d32e0bf3899ef3351067a55e..13bb874075974701befcaea875646802ff531482 100644 --- a/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.hpp +++ b/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.hpp @@ -37,6 +37,7 @@ public: bool SupportsElementwiseLeakyRelu(bool quantized, DataType type) override; bool SupportsRescale(DataType fromType, DataType toType) override; Flags OperatorQuery(OpType opType, const ArchOperatorQuery *query, ArchRequirements *req) override; + bool SupportedZeroPoint(int64_t zp, TensorUsage usage, DataType dtype, OpType opType) override; private: bool SupportedDtypes(OpType opType, DataType ifmType, DataType ifm2Type, DataType ofmType); diff --git a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp index 443b9295518e992b1016fc2c3e03d37c3053c694..b9b721560c600c786d69a3fa54fff7d841ce096f 100644 --- a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp +++ b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp @@ -262,40 +262,40 @@ bool EthosU85Constraints::SupportedDtypes(OpType opType, DataType ifmType, DataT } return true; } -namespace -{ -// Validate that IFM zero-points are supported based on ifmType -bool SupportedIfmZeroPoint(int64_t zp, DataType ifmType) -{ - switch ( ifmType ) - { - case DataType::Int8: - return (zp >= -128) && (zp <= 127); - break; - case DataType::UInt8: - return (zp >= 0) && (zp <= 255); - break; - case DataType::UInt16: - return (zp == 0) || (zp == 32768); - break; - default: - return zp == 0; - } - return false; -} -// Validate that OFM zero-points are supported based on ofmType -bool SupportedOfmZeroPoint(int64_t zp, DataType ofmType) + +// Validate that zero-points are supported +bool EthosU85Constraints::SupportedZeroPoint(int64_t zp, TensorUsage usage, DataType dType, OpType opType) { - if ( IsSignedInteger(ofmType) ) + if ( IsIFM(usage) ) { - return zp >= -128 && zp <= 127; + switch ( dType ) + { + case DataType::Int8: + return (zp >= -128) && (zp <= 127); + break; + case DataType::UInt8: + return (zp >= 0) && (zp <= 255); + break; + case DataType::UInt16: + return (zp == 0) || (zp == 32768); + break; + default: + return zp == 0; + } } - else + else if ( IsOFM(usage) ) { - return (zp == 32768) || (zp >= 0 && zp <= 255); + if ( IsSignedInteger(dType) ) + { + return zp >= -128 && zp <= 127; + } + else + { + return (zp == 32768) || (zp >= 0 && zp <= 255); + } } + return true; } -} // namespace Flags EthosU85Constraints::OperatorQuery(OpType opType, const ArchOperatorQuery *query, ArchRequirements *req) { @@ -393,21 +393,21 @@ Flags EthosU85Constraints::OperatorQuery(OpType opType, const ArchO { for ( auto zp : query->ifm[0].quantization.zeroPoints ) { - if ( !SupportedIfmZeroPoint(zp, ifmType) ) + if ( !SupportedZeroPoint(zp, TensorUsage::IFM0, ifmType, opType) ) { return QueryResult::Unsupported; } } for ( auto zp : query->ifm[1].quantization.zeroPoints ) { - if ( !SupportedIfmZeroPoint(zp, ifm2Type) ) + if ( !SupportedZeroPoint(zp, TensorUsage::IFM1, ifm2Type, opType) ) { return QueryResult::Unsupported; } } for ( auto zp : query->ofm.quantization.zeroPoints ) { - if ( !SupportedOfmZeroPoint(zp, ofmType) ) + if ( !SupportedZeroPoint(zp, TensorUsage::OFM, ofmType, opType) ) { return QueryResult::Unsupported; } diff --git a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.hpp b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.hpp index 66aa057b46620d04745d08fbb6ab26634983fbc2..bc66c2a0358d52f570e187e536834845f75a99a1 100644 --- a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.hpp +++ b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.hpp @@ -36,6 +36,7 @@ public: bool SupportsElementwiseLeakyRelu(bool quantized, DataType type) override { return true; }; bool SupportsRescale(DataType fromType, DataType toType) override; Flags OperatorQuery(OpType opType, const ArchOperatorQuery *query, ArchRequirements *req) override; + bool SupportedZeroPoint(int64_t zp, TensorUsage usage, DataType dtype, OpType opType) override; private: bool SupportedDtypes(OpType opType, DataType ifmType, DataType ifm2Type, DataType ofmType); diff --git a/ethosu/regor/tflite/tflite_supported_operators.cpp b/ethosu/regor/tflite/tflite_supported_operators.cpp index 079d9fda96dbae4e6d6067198f43802396055fba..e09c50204275d3aee465664fe253d39de265e54d 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators.cpp @@ -191,7 +191,7 @@ bool TfLiteSupportedOperators::ConstraintTensQuantized(const Operation *op) bool TfLiteSupportedOperators::ConstraintFCWeightShape(const Operation *op) { - const char *constraint = "FullyConnected weights must be on the form I,1,1,..,1,O"; + const char *constraint = "FullyConnected weights must be on the form O,1,1,..,1,I"; if ( op->Type() != OpType::FullyConnected ) { return true; @@ -206,6 +206,19 @@ bool TfLiteSupportedOperators::ConstraintFCWeightShape(const Operation *op) Failure(op, fmt::format("Unsupported weights shape: {}", shape.ToString()), constraint); return false; } + + // IC and OC must be smaller than 2^16 + // TODO MLBEDSW-10739: Decompose FullyConnected + if ( shape[0] > (1 << 16) ) + { + Failure(op, fmt::format("Output channels: {}", shape[0]), "Output channels must be less than 2^16"); + return false; + } + if ( shape[-1] > (1 << 16) ) + { + Failure(op, fmt::format("Input channels: {}", shape[-1]), "Input channels must be less than 2^16"); + return false; + } return true; } @@ -259,6 +272,32 @@ bool TfLiteSupportedOperators::ConstraintMatchingQuantization(const Operation *o return true; } +bool TfLiteSupportedOperators::ConstraintZeroPoints(const Operation *op) +{ + OpType opType = op->Type(); + // zeroPoints are ignored for the following operations to align with reference + if ( opType == OpType::AvgPool || opType == OpType::Resize || opType == OpType::CLZ || opType == OpType::SHL || opType == OpType::Div ) + { + return true; + } + for ( const auto *list : {&op->Inputs(), &op->Outputs()} ) + { + for ( const auto &[usage, conn] : list->pairs() ) + { + DataType dType = conn.tensor->Type(); + for ( auto zp : conn.quantization.zeroPoints ) + { + if ( !_archConstraints->SupportedZeroPoint(zp, usage, dType, opType) ) + { + Failure(op, fmt::format("tensor {} has unsupported zeroPoint: {}", conn.tensor->Name(), zp)); + return false; + } + } + } + } + return true; +} + bool TfLiteSupportedOperators::ConstraintWeightsPrecision(const Operation *op) { const char *constraint = "Weight tensors must be 8-bit precision"; @@ -923,6 +962,7 @@ TfLiteSupportedOperators::TfLiteSupportedOperators(IArchitectureConstraints *con &TfLiteSupportedOperators::ConstraintTensQuantized, &TfLiteSupportedOperators::ConstraintPerAxisQuant, &TfLiteSupportedOperators::ConstraintMatchingQuantization, + &TfLiteSupportedOperators::ConstraintZeroPoints, &TfLiteSupportedOperators::ConstraintWeightsPrecision, &TfLiteSupportedOperators::ConstraintWeightSum, &TfLiteSupportedOperators::ConstraintBias, diff --git a/ethosu/regor/tflite/tflite_supported_operators.hpp b/ethosu/regor/tflite/tflite_supported_operators.hpp index 4b668abfd16f7097083ccf89001eec860f6e7262..9ad96127dba9259dc916c9537b3f05f06f092801 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.hpp +++ b/ethosu/regor/tflite/tflite_supported_operators.hpp @@ -59,6 +59,7 @@ private: bool ConstraintFCWeightShape(const Operation *op); bool ConstraintPerAxisQuant(const Operation *op); bool ConstraintMatchingQuantization(const Operation *op); + bool ConstraintZeroPoints(const Operation *op); bool ConstraintDepthMultiplier(const Operation *op); bool ConstraintWeightsPrecision(const Operation *op); bool ConstraintWeightSum(const Operation *op);