// Patch summary (git diff against C++ sources under ethosu/regor).
//
// Purpose: carry quantization information on ArchFM so that
// EthosU85Constraints::OperatorQuery can reject operators whose zero-points
// fall outside the accepted values.
//
// Per-file changes:
//  * architecture/architecture_constraints.hpp
//      ArchFM gains a `Quantization quantization` member, defaulted with `= {}`.
//  * architecture/ethosu85/ethos_u85_constraints.cpp
//      Adds two helpers inside an anonymous namespace:
//        - SupportedIfmZeroPoint(zp, ifmType): Int8 accepts [-128, 127], UInt8
//          accepts [0, 255], UInt16 accepts only 0 or 32768, and every other
//          type accepts only 0.
//        - SupportedOfmZeroPoint(zp, ofmType): types for which IsSignedInteger()
//          is true accept [-128, 127]; all other types accept 32768 or [0, 255].
//      OperatorQuery caches ifm[0]/ifm[1]/ofm types in locals and, only when both
//      the ifm[0] and ofm types are known (typeInfo), checks every zero-point of
//      ifm[0], ifm[1] and ofm, returning QueryResult::Unsupported on the first
//      rejected value. The check is inserted ahead of the existing data-type
//      validation.
//  * compiler/operation_util.hpp
//      Set(ArchFM&, ...) now takes a `const TensorConnection *` instead of a
//      `const Tensor *`; the shape is read from conn->shape (previously
//      src->StorageShape()) and the connection's quantization is copied too.
//  * compiler/scheduler_decompose.hpp
//      The SchedulerConnection overload of Set() also copies conn->quantization.
//  * compiler/graphir_optimiser.cpp (MergeTransposes)
//      Call sites pass ifmConn / ofmConn in place of ifmConn->tensor.get() /
//      ofmConn->tensor.get().
//  * compiler/tflite_graph_optimiser.cpp (ConvertTanhSigmoidToLUT)
//      Call sites pass ifmConn / operation->Output(TensorUsage::OFM) in place of
//      ifm / operation->OFM().
//
// NOTE(review): in SupportedIfmZeroPoint the `break;` after each `return` and the
//   trailing `return false;` can never execute, because every switch path returns.
// NOTE(review): the TensorConnection overload of Set() null-checks `conn` but
//   dereferences conn->tensor unconditionally - confirm tensor is always set.
// NOTE(review): typeInfo does not look at ifm[1].type; when that type is anything
//   other than Int8/UInt8/UInt16, a non-zero ifm[1] zero-point is rejected -
//   confirm this is intended for operators without a second input.
// NOTE(review): SupportedOfmZeroPoint applies the [-128, 127] range to every
//   signed integer type, not only 8-bit - presumably intentional; verify.
diff --git a/ethosu/regor/architecture/architecture_constraints.hpp b/ethosu/regor/architecture/architecture_constraints.hpp index adad6f59962dd5bff74575320bf0ac30b87b9868..dc6dd867b7fc74ab6d34bd026b06ca84489531e7 100644 --- a/ethosu/regor/architecture/architecture_constraints.hpp +++ b/ethosu/regor/architecture/architecture_constraints.hpp @@ -43,6 +43,7 @@ struct ArchFM Shape shape; DataType type = {}; TensorFormat format = {}; + Quantization quantization = {}; }; struct ArchOperatorQuery diff --git a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp index 67270d9ea915d6a32a7c91aee1af0148e0339a3a..443b9295518e992b1016fc2c3e03d37c3053c694 100644 --- a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp +++ b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp @@ -262,6 +262,40 @@ bool EthosU85Constraints::SupportedDtypes(OpType opType, DataType ifmType, DataT } return true; } +namespace +{ +// Validate that IFM zero-points are supported based on ifmType +bool SupportedIfmZeroPoint(int64_t zp, DataType ifmType) +{ + switch ( ifmType ) + { + case DataType::Int8: + return (zp >= -128) && (zp <= 127); + break; + case DataType::UInt8: + return (zp >= 0) && (zp <= 255); + break; + case DataType::UInt16: + return (zp == 0) || (zp == 32768); + break; + default: + return zp == 0; + } + return false; +} +// Validate that OFM zero-points are supported based on ofmType +bool SupportedOfmZeroPoint(int64_t zp, DataType ofmType) +{ + if ( IsSignedInteger(ofmType) ) + { + return zp >= -128 && zp <= 127; + } + else + { + return (zp == 32768) || (zp >= 0 && zp <= 255); + } +} +} // namespace Flags EthosU85Constraints::OperatorQuery(OpType opType, const ArchOperatorQuery *query, ArchRequirements *req) { @@ -342,7 +376,10 @@ Flags EthosU85Constraints::OperatorQuery(OpType opType, const ArchO const auto &ifmShape = query->ifm[0].shape; const auto &ifm2Shape = query->ifm[1].shape; const auto 
&ofmShape = query->ofm.shape; - bool typeInfo = (query->ifm[0].type != DataType::None && query->ofm.type != DataType::None); + const auto ifmType = query->ifm[0].type; + const auto ifm2Type = query->ifm[1].type; + const auto ofmType = query->ofm.type; + bool typeInfo = (ifmType != DataType::None && ofmType != DataType::None); bool shapeInfo = (ifmShape && ofmShape); if ( !typeInfo || !shapeInfo || !query->kernel ) @@ -351,6 +388,32 @@ Flags EthosU85Constraints::OperatorQuery(OpType opType, const ArchO result.Set(QueryResult::Constrained); } + // Validate zeroPoints + if ( typeInfo ) + { + for ( auto zp : query->ifm[0].quantization.zeroPoints ) + { + if ( !SupportedIfmZeroPoint(zp, ifmType) ) + { + return QueryResult::Unsupported; + } + } + for ( auto zp : query->ifm[1].quantization.zeroPoints ) + { + if ( !SupportedIfmZeroPoint(zp, ifm2Type) ) + { + return QueryResult::Unsupported; + } + } + for ( auto zp : query->ofm.quantization.zeroPoints ) + { + if ( !SupportedOfmZeroPoint(zp, ofmType) ) + { + return QueryResult::Unsupported; + } + } + } + // Validate dataTypes if ( typeInfo && !SupportedDtypes(opType, query->ifm[0].type, query->ifm[1].type, query->ofm.type) ) { diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp index 4dd2305c6fb7f3d0f3c39a2ac0b781f4b677ac7e..5d30cbbec0550fa85fcaa17556bdf8f82eb3390b 100644 --- a/ethosu/regor/compiler/graphir_optimiser.cpp +++ b/ethosu/regor/compiler/graphir_optimiser.cpp @@ -1646,8 +1646,8 @@ Operation *GraphIrOptimiser::MergeTransposes(Graph *const graph, Operation *cons ArchOperatorQuery query; ArchRequirements req; query.transposeMask = mergedTranspose; - Set(query.ifm[0], ifmConn->tensor.get()); - Set(query.ofm, ofmConn->tensor.get()); + Set(query.ifm[0], ifmConn); + Set(query.ofm, ofmConn); if ( _constraints->OperatorQuery(OpType::Transpose, &query, &req).Any(QueryResult::Native) ) { // only merge the transpose if the new mask is natively supported diff --git 
a/ethosu/regor/compiler/operation_util.hpp b/ethosu/regor/compiler/operation_util.hpp index 2d32aa56b287b53d3e44b40499a87549f0167314..d3fc1edcfa4821c03e3b1ec878ca27d684eaddad 100644 --- a/ethosu/regor/compiler/operation_util.hpp +++ b/ethosu/regor/compiler/operation_util.hpp @@ -282,12 +282,13 @@ inline bool IsScalingValidAndEqual(const TensorConnection &a, const TensorConnec #undef FOR_ALL_INT_TYPES -inline ArchFM &Set(ArchFM &fm, const Tensor *src) +inline ArchFM &Set(ArchFM &fm, const TensorConnection *conn) { - if ( src ) + if ( conn ) { - fm.type = src->Type(); - fm.shape = src->StorageShape(); + fm.type = conn->tensor->Type(); + fm.shape = conn->shape; + fm.quantization = conn->quantization; } return fm; } diff --git a/ethosu/regor/compiler/scheduler_decompose.hpp b/ethosu/regor/compiler/scheduler_decompose.hpp index 8853c7979ade75de9f4f26d720f9509237b25847..8850241134c9181de1238d4f354a685bc2874d5b 100644 --- a/ethosu/regor/compiler/scheduler_decompose.hpp +++ b/ethosu/regor/compiler/scheduler_decompose.hpp @@ -55,6 +55,7 @@ inline ArchFM &Set(ArchFM &fm, const SchedulerConnection *conn) fm.type = conn->tensor->dataType; fm.shape = conn->SliceShape(); fm.format = conn->tensor->format; + fm.quantization = conn->quantization; } return fm; } diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp index b0758ae20cd740c71198c1f02418375ff77fd907..eec46ff18948d1e34dadb49ad2536d9b16a9c3a2 100644 --- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp +++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp @@ -1969,8 +1969,8 @@ Operation *TFLiteGraphOptimiser::ConvertTanhSigmoidToLUT(Graph *const, Operation } ArchOperatorQuery query; - Set(query.ifm[0], ifm); - Set(query.ofm, operation->OFM()); + Set(query.ifm[0], ifmConn); + Set(query.ofm, operation->Output(TensorUsage::OFM)); ArchRequirements req; auto qresult = _constraints->OperatorQuery(opType, &query, &req); assert(qresult.Any(QueryResult::Native));