diff --git a/ethosu/regor/architecture/architecture_constraints.hpp b/ethosu/regor/architecture/architecture_constraints.hpp
index e5f5964c3d5251140406d8fff595365bf38e1ecc..9790374b917ed2d199b3126f25b914cd0c549a77 100644
--- a/ethosu/regor/architecture/architecture_constraints.hpp
+++ b/ethosu/regor/architecture/architecture_constraints.hpp
@@ -43,7 +43,7 @@ struct ArchFM
     Shape shape;
     DataType type = {};
     TensorFormat format = {};
-    Quantization quantization = {};
+    const Quantization *quantization = nullptr;
 };
 
 struct ArchOperatorQuery
@@ -59,12 +59,10 @@ struct ArchOperatorQuery
 
 enum class ArchRequirement
 {
-    None = 0,
-    ScratchTensor = 1 << 0,
-    OutputFormat = 1 << 1,
-    InputFormat = 1 << 2,
-    OpSubstitution = 1 << 3,
-    Decompose = 1 << 4,
+    None = 0x00,
+    Tensor = 0x01,          // Tensor requirement
+    OpSubstitution = 0x02,  // Operator substitution
+    Decompose = 0x04,       // Decompose
 };
 
 enum class ArchProperty
@@ -80,18 +78,19 @@ enum class ArchProperty
     Scaling = 1 << 7,
 };
 
+struct ArchTensorRequirement
+{
+    const ArchTensorRequirement *next = nullptr;
+    TensorUsage usage = TensorUsage::None;
+    TensorFormat format = TensorFormat::Unknown;
+    DataType type = DataType::None;
+    Shape shape;
+};
+
 struct ArchRequirements
 {
     Flags<ArchRequirement> req;
-    struct
-    {
-        Shape size;
-        DataType type = DataType::None;
-        TensorFormat format = TensorFormat::Unknown;
-    } scratch;
-    TensorFormat ifmFormat = TensorFormat::Unknown;
-    TensorFormat ifm1Format = TensorFormat::Unknown;
-    TensorFormat ofmFormat = TensorFormat::Unknown;
+    ArchTensorRequirement tensor;
     OpType substitution = OpType::None;
     Flags<ArchProperty> decomposeProps;
 };
@@ -138,4 +137,32 @@ public:
     virtual bool SupportedZeroPoint(int64_t zp, TensorUsage usage, DataType dType, OpType opType) = 0;
 };
 
+inline void Set(ArchTensorRequirement &req, TensorUsage usage, TensorFormat format)
+{
+    req.usage = usage;
+    req.format = format;
+    req.next = nullptr;
+}
+
+inline void Set(ArchTensorRequirement &req, TensorUsage usage, DataType type, TensorFormat format)
+{
+    Set(req, usage, format);
+    req.type = type;
+}
+
+inline void Set(ArchTensorRequirement &req, TensorUsage usage, DataType type, TensorFormat format, const Shape &shape)
+{
+    Set(req, usage, type, format);
+    req.shape = shape;
+}
+
+inline const ArchTensorRequirement *Get(const ArchTensorRequirement *req, TensorUsage usage)
+{
+    while ( req->usage != usage && req->next != nullptr )
+    {
+        req = req->next;
+    }
+    return req;
+}
+
 }  // namespace regor
diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp b/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp
index f203354fcc808ef44f9075607b6812b4df8b44b2..f828d305b90ccbdd9547d52436d5aa299cd25737 100644
--- a/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp
+++ b/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp
@@ -24,38 +24,43 @@
 namespace regor
 {
 
 // Table of allowed ifm/ofm data type combinations for each HWOp
-static const std::unordered_map<EthosU55NpuOp, std::unordered_map<DataType, std::vector<DataType>>> s_opDataTypeSupport = {
+static const std::array<DataType, 4> s_defaultAllTypes = {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32};
+static const std::array<DataType, 1> s_defaultUInt8Only = {DataType::UInt8};
+static const std::array<DataType, 1> s_defaultInt8Only = {DataType::Int8};
+static const std::array<DataType, 1> s_defaultInt16Only = {DataType::Int16};
+
+static const std::unordered_map<EthosU55NpuOp, std::unordered_map<DataType, readonly_span_t<DataType>>> s_opDataTypeSupport = {
     {EthosU55NpuOp::Convolution,  // HWOp
         {
            // IFM data type | OFM data type(s)
-            {DataType::UInt8, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
-            {DataType::Int8, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
-            {DataType::Int16, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
+            {DataType::UInt8, s_defaultAllTypes},
+            {DataType::Int8, s_defaultAllTypes},
+            {DataType::Int16, s_defaultAllTypes},
        }},
     {EthosU55NpuOp::Depthwise,
        {
-            {DataType::UInt8, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
-            {DataType::Int8, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
-            {DataType::Int16, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
+            {DataType::UInt8, s_defaultAllTypes},
+            {DataType::Int8, s_defaultAllTypes},
+            {DataType::Int16, s_defaultAllTypes},
        }},
     {EthosU55NpuOp::VectorProduct,
        {
-            {DataType::UInt8, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
-            {DataType::Int8, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
-            {DataType::Int16, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
+            {DataType::UInt8, s_defaultAllTypes},
+            {DataType::Int8, s_defaultAllTypes},
+            {DataType::Int16, s_defaultAllTypes},
        }},
     {EthosU55NpuOp::Pooling,
        {
-            {DataType::UInt8, {DataType::UInt8}},
-            {DataType::Int8, {DataType::Int8}},
-            {DataType::Int16, {DataType::Int16}},
+            {DataType::UInt8, s_defaultUInt8Only},
+            {DataType::Int8, s_defaultInt8Only},
+            {DataType::Int16, s_defaultInt16Only},
        }},
     {EthosU55NpuOp::ReduceSum,
        {
-            {DataType::UInt8, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
-            {DataType::Int8, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
-            {DataType::Int16, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
-            {DataType::Int32, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
+            {DataType::UInt8, s_defaultAllTypes},
+            {DataType::Int8, s_defaultAllTypes},
+            {DataType::Int16, s_defaultAllTypes},
+            {DataType::Int32, s_defaultAllTypes},
        }},
 };
@@ -182,6 +187,15 @@ bool EthosU55Constraints::SupportsRescale(DataType fromType, DataType toType)
     return true;
 }
 
+static const std::array<DataType, 4> s_validAddMulTypes = {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32};
+static const std::array<DataType, 4> s_validMaxAbsTypes = {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32};
+static const std::array<DataType, 1> s_validClzShlTypes = {DataType::Int32};
+static const std::array<DataType, 4> s_validAsrOfmTypes = {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32};
+static const std::array<DataType, 3> s_validReverseTypes = {DataType::UInt8, DataType::Int8, DataType::Int16};
+
+
 bool EthosU55Constraints::SupportedDtypes(OpType opType, DataType ifmType, DataType ifm2Type, DataType ofmType)
 {
     auto npuOp = _arch->GetHWOp(opType);
@@ -202,6 +216,8 @@ bool EthosU55Constraints::SupportedDtypes(OpType opType, DataType ifmType, DataT
         return true;
     }
 
+    readonly_span_t<DataType> ofmTypes;
+
     // Check allowed ifm/ofm type mapping
     if ( npuOp != EthosU55NpuOp::Elementwise )
     {
@@ -214,27 +230,23 @@ bool EthosU55Constraints::SupportedDtypes(OpType opType, DataType ifmType, DataT
         auto &typeMap = map->second;
         auto ifmEntry = typeMap.find(ifmType);
         if ( ifmEntry == typeMap.end() )
-        {  // Unsupported ifm data type
-            return false;
-        }
-        auto &ofmTypes = ifmEntry->second;
-        if ( 0 == std::count(ofmTypes.begin(), ofmTypes.end(), ofmType) )
-        {  // Unsupported ofm data type
+        {
+            // Unsupported ifm data type
             return false;
         }
+        ofmTypes = ifmEntry->second;
     }
     else
     {
-        std::vector<DataType> validIfmTypes;
-        std::vector<DataType> validOfmTypes;
+        readonly_span_t<DataType> ifmTypes;
         switch ( opType )
         {
             case OpType::Add:
             case OpType::Sub:
             case OpType::Mul:
             {
-                validIfmTypes = {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32};
-                validOfmTypes = validIfmTypes;
+                ifmTypes = s_validAddMulTypes;
+                ofmTypes = s_validAddMulTypes;
             }
             break;
             case OpType::Minimum:
@@ -242,45 +254,50 @@ bool EthosU55Constraints::SupportedDtypes(OpType opType, DataType ifmType, DataT
             case OpType::LeakyRelu:
             case OpType::Abs:
             {
-                validIfmTypes = {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32};
-                validOfmTypes = {ifmType};
+                ifmTypes = s_validMaxAbsTypes;
+                ofmTypes = s_validMaxAbsTypes;
             }
             break;
             case OpType::CLZ:
             case OpType::SHL:
+            {
+                ifmTypes = s_validClzShlTypes;
+                ofmTypes = s_validClzShlTypes;
+            }
+            break;
             case OpType::Asr:
             {
-                validIfmTypes = {DataType::Int32};
-                validOfmTypes = {DataType::Int32};
-                if ( opType == OpType::Asr )
-                {
-                    validOfmTypes.insert(validOfmTypes.begin(), {DataType::UInt8, DataType::Int8, DataType::Int16});
-                }
+                ifmTypes = s_validClzShlTypes;
+                ofmTypes = s_validAsrOfmTypes;
             }
             break;
             case OpType::Reverse:
             {
-                validIfmTypes = {DataType::UInt8, DataType::Int8, DataType::Int16};
-                validOfmTypes = {DataType::UInt8, DataType::Int8, DataType::Int16};
+                ifmTypes = s_validReverseTypes;
+                ofmTypes = s_validReverseTypes;
            }
            break;
            default:
                assert(false && "Unkown elementwise type");
                break;
        }
-        if ( 0 == std::count(validIfmTypes.begin(), validIfmTypes.end(), ifmType) )
-        {  // Unsupported ifm data type
+        if ( !std::any_of(ifmTypes.begin(), ifmTypes.end(), [&](auto t) { return t == ifmType; }) )
+        {
+            // Unsupported ifm data type
             return false;
         }
         if ( IsBinaryElementwise(opType) && ifm2Type != ifmType )
-        {  // ifm2 data type must match ifm data type
-            return false;
-        }
-        if ( 0 == std::count(validOfmTypes.begin(), validOfmTypes.end(), ofmType) )
-        {  // Unsupported ofm data type
+        {
+            // ifm2 data type must match ifm data type
             return false;
         }
     }
+
+    if ( !std::any_of(ofmTypes.begin(), ofmTypes.end(), [&](auto t) { return t == ofmType; }) )
+    {  // Unsupported ofm data type
+        return false;
+    }
+
     return true;
 }
@@ -317,24 +334,26 @@ bool EthosU55Constraints::SupportedZeroPoint(int64_t zp, TensorUsage usage, Data
     return true;
 }
 
+namespace
+{
+
+// Backing store for chained tensor requirements; the size only needs to
+// cover the longest chain built below (MatMul links two extra entries).
+thread_local std::array<ArchTensorRequirement, 4> s_extraTensorReq;
+
+ArchTensorRequirement *NextTensor(ArchTensorRequirement *tr, unsigned &used)
+{
+    assert(used < s_extraTensorReq.size());
+    ArchTensorRequirement *info = &s_extraTensorReq[used++];
+    tr->next = info;
+    return info;
+}
+
+}  // namespace
+
 Flags<QueryResult> EthosU55Constraints::OperatorQuery(OpType opType, const ArchOperatorQuery *query, ArchRequirements *req)
 {
-    Flags<QueryResult> result = QueryResult::Native;
     static constexpr int32_t MAX_AXIS = (1 << 16);
 
-    // Check hardware-required substitutions first
-    if ( (opType == OpType::Sigmoid) || (opType == OpType::Tanh) )
-    {
-        if ( query && query->ifm[0].type != DataType::Int16 )
-        {
-            if ( req )
-            {
-                req->req.Set(ArchRequirement::OpSubstitution);
-                req->substitution = OpType::LUT;
-            }
-            result.Set(QueryResult::HasRequirements);
-        }
-    }
+    unsigned usedTensors = 0;
 
     if ( opType == OpType::Resize )
     {
         if ( query->ifm[0].shape.ElementsWH() == 1 )
         {
@@ -346,12 +365,10 @@ Flags<QueryResult> EthosU55Constraints::OperatorQuery(OpType opType, const ArchO
                 req->req = ArchRequirement::Decompose;
                 req->substitution = OpType::AvgPool;
             }
-            result.Set(QueryResult::HasRequirements);
-            return result;
+            return QueryResult::NativeHasReq;
         }
     }
-    // TransposeConv2D and Conv3D are legalized during decomposition
-    if ( opType == OpType::TransposeConv2D || opType == OpType::Conv3D )
+    else if ( opType == OpType::TransposeConv2D || opType == OpType::Conv3D )
     {
         if ( req )
         {
@@ -368,16 +385,18 @@ Flags<QueryResult> EthosU55Constraints::OperatorQuery(OpType opType, const ArchO
     }
     else if ( npuOp == EthosU55NpuOp::Dma )
     {
-        return result;
+        return QueryResult::Native;
     }
 
     // Short query (no additional detail)
     if ( !query )
     {
-        // more detailed query might fail
+        // More detailed query might fail (constrained)
         return QueryResult::NativeConstrained;
     }
 
+    Flags<QueryResult> result = QueryResult::Native;
+
     if ( npuOp == EthosU55NpuOp::ReduceSum )
     {
         // unsupported reduce axis (only C supported)
@@ -391,6 +410,19 @@ Flags<QueryResult> EthosU55Constraints::OperatorQuery(OpType opType, const ArchO
             result.Set(QueryResult::HasRequirements);
         }
     }
+    // Check required substitutions first
+    else if ( (opType == OpType::Sigmoid) || (opType == OpType::Tanh) )
+    {
+        if ( query->ifm[0].type != DataType::Int16 )
+        {
+            if ( req )
+            {
+                req->req.Set(ArchRequirement::OpSubstitution);
+                req->substitution = OpType::LUT;
+            }
+            result.Set(QueryResult::HasRequirements);
+        }
+    }
 
     const auto &ifmShape = query->ifm[0].shape;
     const auto &ifm2Shape = query->ifm[1].shape;
@@ -407,32 +439,6 @@ Flags<QueryResult> EthosU55Constraints::OperatorQuery(OpType opType, const ArchO
         result.Set(QueryResult::Constrained);
     }
 
-    // Validate zeroPoints
-    if ( typeInfo )
-    {
-        for ( auto zp : query->ifm[0].quantization.zeroPoints )
-        {
-            if ( !SupportedZeroPoint(zp, TensorUsage::IFM0, ifmType, opType) )
-            {
-                return QueryResult::Unsupported;
-            }
-        }
-        for ( auto zp : query->ifm[1].quantization.zeroPoints )
-        {
-            if ( !SupportedZeroPoint(zp, TensorUsage::IFM1, ifm2Type, opType) )
-            {
-                return QueryResult::Unsupported;
-            }
-        }
-        for ( auto zp : query->ofm.quantization.zeroPoints )
-        {
-            if ( !SupportedZeroPoint(zp, TensorUsage::OFM, ofmType, opType) )
-            {
-                return QueryResult::Unsupported;
-            }
-        }
-    }
-
     // Validate DataTypes
     if ( typeInfo && !SupportedDtypes(opType, query->ifm[0].type, query->ifm[1].type, query->ofm.type) )
     {
@@ -476,12 +482,8 @@ Flags<QueryResult> EthosU55Constraints::OperatorQuery(OpType opType, const ArchO
     }
 
     // Detailed operator queries
-    if ( !IsNone(query->transposeMask) )
+    if ( opType == OpType::Transpose )
     {
-        if ( opType != OpType::Transpose )
-        {
-            return QueryResult::Unsupported;
-        }
         // TODO MLBEDSW-10668: Transpose-implementation does not support large-axis decomposition
         if ( req && req->decomposeProps.Any(ArchProperty::TensorAxis) )
         {
@@ -505,14 +507,20 @@ Flags<QueryResult> EthosU55Constraints::OperatorQuery(OpType opType, const ArchO
             }
             result.Set(QueryResult::HasRequirements);
         }
-        // Always set Input/Output format requirements
         if ( req )
         {
-            req->req.Set(ArchRequirement::OutputFormat, ArchRequirement::InputFormat);
-            req->ifmFormat = TensorFormat::NHWC;
-            req->ofmFormat = TensorFormat::NHWC;
+            req->req.Set(ArchRequirement::Tensor);
+            Set(req->tensor, TensorUsage::IFM, TensorFormat::NHWC);
+            Set(*NextTensor(&req->tensor, usedTensors), TensorUsage::OFM, TensorFormat::NHWC);
+        }
+        return result;
+    }
+    else
+    {
+        if ( !IsNone(query->transposeMask) )
+        {
+            return QueryResult::Unsupported;
         }
-        result.Set(QueryResult::HasRequirements);
     }
@@ -521,20 +529,57 @@ Flags<QueryResult> EthosU55Constraints::OperatorQuery(OpType opType, const ArchO
     // reverseType::W and reverseType::H are supported
     if ( query->reverseMask != ReverseType::None && query->reverseMask != ReverseType::W && query->reverseMask != ReverseType::H )
     {
         return QueryResult::Unsupported;
     }
 
+    // Validate zeroPoints
+    if ( typeInfo )
+    {
+        if ( query->ifm[0].quantization )
+        {
+            for ( auto zp : query->ifm[0].quantization->zeroPoints )
+            {
+                if ( !SupportedZeroPoint(zp, TensorUsage::IFM0, ifmType, opType) )
+                {
+                    return QueryResult::Unsupported;
+                }
+            }
+        }
+        if ( query->ifm[1].quantization )
+        {
+            for ( auto zp : query->ifm[1].quantization->zeroPoints )
+            {
+                if ( !SupportedZeroPoint(zp, TensorUsage::IFM1, ifm2Type, opType) )
+                {
+                    return QueryResult::Unsupported;
+                }
+            }
+        }
+        if ( query->ofm.quantization )
+        {
+            for ( auto zp : query->ofm.quantization->zeroPoints )
+            {
+                if ( !SupportedZeroPoint(zp, TensorUsage::OFM, ofmType, opType) )
+                {
+                    return QueryResult::Unsupported;
+                }
+            }
+        }
+    }
+
     if ( opType == OpType::MatMul )
     {
         if ( req )
         {
-            req->req.Set(ArchRequirement::ScratchTensor, ArchRequirement::OutputFormat, ArchRequirement::InputFormat);
+            req->req.Set(ArchRequirement::Tensor);
+            ArchTensorRequirement *tr = &req->tensor;
             if ( query->ifm[0].shape )
             {
-                req->scratch.size = query->ifm[0].shape.WithDepth(query->ifm[0].shape.Depth() + 1);
-                req->scratch.type = DataType::Int32;
-                req->scratch.format = TensorFormat::NHWC;
+                req->req.Set(ArchRequirement::Tensor);
+                Set(*tr, TensorUsage::Scratch, DataType::Int32, TensorFormat::NHWC,
+                    query->ifm[0].shape.WithDepth(query->ifm[0].shape.Depth() + 1));
+                tr = NextTensor(tr, usedTensors);
             }
-            req->ifmFormat = TensorFormat::Unknown;
-            req->ifm1Format = TensorFormat::NHWC;  // IFM1 and OFM are depth-sliced
-            req->ofmFormat = TensorFormat::NHWC;   // and cannot be addressed if B16
+            Set(*tr, TensorUsage::IFM1, TensorFormat::NHWC);
+            tr = NextTensor(tr, usedTensors);
+            Set(*tr, TensorUsage::OFM, TensorFormat::NHWC);
         }
         result.Set(QueryResult::HasRequirements);
     }
diff --git a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp
index d2de7f90510a786eae2ba803aa17df924999850b..57bcf3537d9cfbfc4d05b33cfd9d30c846505138 100644
--- a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp
+++ b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp
@@ -26,61 +26,73 @@
 namespace regor
 {
 
+// Table of allowed ifm/ofm data type combinations for each HWOp
+static const std::array<DataType, 5> s_defaultAllTypes = {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32, DataType::Int64};
+static const std::array<DataType, 4> s_defaultAllTypesExcl64 = {
+    DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32};
+static const std::array<DataType, 5> s_reduceMinMaxTypes = {
+    DataType::Bool8, DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32};
+static const std::array<DataType, 3> s_poolingTypeBool = {DataType::Bool8, DataType::Int32, DataType::Int64};
+static const std::array<DataType, 3> s_poolingTypeUInt8 = {DataType::UInt8, DataType::Int32, DataType::Int64};
+static const std::array<DataType, 3> s_poolingTypeInt8 = {DataType::Int8, DataType::Int32, DataType::Int64};
+static const std::array<DataType, 3> s_poolingTypeInt16 = {DataType::Int16, DataType::Int32, DataType::Int64};
+static const std::array<DataType, 2> s_defaultInt32_64 = {DataType::Int32, DataType::Int64};
+
 // TODO: This table is from the EthosU55/U65 Embedded NPU Interface Specification, it's not completely valid for
 // Ethos U85 since the allowed data types depend on ifm/ofm as well as selected acc and scaling.
-static const std::unordered_map<EthosU85NpuOp, std::unordered_map<DataType, std::vector<DataType>>> s_opDataTypeSupport = {
+static const std::unordered_map<EthosU85NpuOp, std::unordered_map<DataType, readonly_span_t<DataType>>> s_opDataTypeSupport = {
     {EthosU85NpuOp::Convolution,
        {
-            {DataType::UInt8, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32, DataType::Int64}},
-            {DataType::Int8, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32, DataType::Int64}},
-            {DataType::Int16, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32, DataType::Int64}},
+            {DataType::UInt8, s_defaultAllTypes},
+            {DataType::Int8, s_defaultAllTypes},
+            {DataType::Int16, s_defaultAllTypes},
        }},
     {EthosU85NpuOp::Depthwise,
        {
-            {DataType::UInt8, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32, DataType::Int64}},
-            {DataType::Int8, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32, DataType::Int64}},
-            {DataType::Int16, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32, DataType::Int64}},
+            {DataType::UInt8, s_defaultAllTypes},
+            {DataType::Int8, s_defaultAllTypes},
+            {DataType::Int16, s_defaultAllTypes},
        }},
     {EthosU85NpuOp::VectorProduct,
        {
-            {DataType::UInt8, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32, DataType::Int64}},
-            {DataType::Int8, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32, DataType::Int64}},
-            {DataType::Int16, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32, DataType::Int64}},
+            {DataType::UInt8, s_defaultAllTypes},
+            {DataType::Int8, s_defaultAllTypes},
+            {DataType::Int16, s_defaultAllTypes},
        }},
     {EthosU85NpuOp::Pooling,
        {
-            {DataType::Bool8, {DataType::Bool8, DataType::Int32, DataType::Int64}},
-            {DataType::UInt8, {DataType::UInt8, DataType::Int32, DataType::Int64}},
-            {DataType::Int8, {DataType::Int8, DataType::Int32, DataType::Int64}},
-            {DataType::Int16, {DataType::Int16, DataType::Int32, DataType::Int64}},
+            {DataType::Bool8, s_poolingTypeBool},
+            {DataType::UInt8, s_poolingTypeUInt8},
+            {DataType::Int8, s_poolingTypeInt8},
+            {DataType::Int16, s_poolingTypeInt16},
        }},
     {EthosU85NpuOp::ReduceMinMax,
        {
-            {DataType::Bool8, {DataType::Bool8, DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
-            {DataType::UInt8, {DataType::Bool8, DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
-            {DataType::Int8, {DataType::Bool8, DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
-            {DataType::Int16, {DataType::Bool8, DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
-            {DataType::Int32, {DataType::Bool8, DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
+            {DataType::Bool8, s_reduceMinMaxTypes},
+            {DataType::UInt8, s_reduceMinMaxTypes},
+            {DataType::Int8, s_reduceMinMaxTypes},
+            {DataType::Int16, s_reduceMinMaxTypes},
+            {DataType::Int32, s_reduceMinMaxTypes},
        }},
     {EthosU85NpuOp::ReduceSum,
        {
-            {DataType::UInt8, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
-            {DataType::Int8, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
-            {DataType::Int16, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
-            {DataType::Int32, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}},
+            {DataType::UInt8, s_defaultAllTypesExcl64},
+            {DataType::Int8, s_defaultAllTypesExcl64},
+            {DataType::Int16, s_defaultAllTypesExcl64},
+            {DataType::Int32, s_defaultAllTypesExcl64},
        }},
     {EthosU85NpuOp::ArgMax,
        {
-            {DataType::Bool8, {DataType::Int32, DataType::Int64}},
-            {DataType::UInt8, {DataType::Int32, DataType::Int64}},
-            {DataType::Int8, {DataType::Int32, DataType::Int64}},
-            {DataType::Int16, {DataType::Int32, DataType::Int64}},
+            {DataType::Bool8, s_defaultInt32_64},
+            {DataType::UInt8, s_defaultInt32_64},
+            {DataType::Int8, s_defaultInt32_64},
+            {DataType::Int16, s_defaultInt32_64},
        }},
     {EthosU85NpuOp::Resize,
        {
-            {DataType::UInt8, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32, DataType::Int64}},
-            {DataType::Int8, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32, DataType::Int64}},
-            {DataType::Int16, {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32, DataType::Int64}},
+            {DataType::UInt8, s_defaultAllTypes},
+            {DataType::Int8, s_defaultAllTypes},
+            {DataType::Int16, s_defaultAllTypes},
        }},
 };
@@ -234,6 +246,8 @@ bool EthosU85Constraints::SupportedDtypes(OpType opType, DataType ifmType, DataT
         return true;
     }
 
+    readonly_span_t<DataType> ofmTypes;
+
     if ( npuOp != EthosU85NpuOp::Elementwise )
     {
         auto map = s_opDataTypeSupport.find(npuOp);
@@ -249,17 +263,31 @@ bool EthosU85Constraints::SupportedDtypes(OpType opType, DataType ifmType, DataT
             // Unsupported ifm data type
             return false;
         }
-        auto &ofmTypes = ifmEntry->second;
-        if ( 0 == std::count(ofmTypes.begin(), ofmTypes.end(), ofmType) )
-        {
-            // Unsupported ofm data type
-            return false;
-        }
+        ofmTypes = ifmEntry->second;
     }
     else
     {
         // TODO elementwise
+        readonly_span_t<DataType> ifmTypes = s_defaultAllTypes;
+        ofmTypes = s_defaultAllTypes;
+
+        if ( !std::any_of(ifmTypes.begin(), ifmTypes.end(), [&](auto t) { return t == ifmType; }) )
+        {
+            // Unsupported ifm data type
+            return false;
+        }
+        if ( IsBinaryElementwise(opType) && ifm2Type != ifmType )
+        {
+            // ifm2 data type must match ifm data type
+            return false;
+        }
     }
+
+    if ( !std::any_of(ofmTypes.begin(), ofmTypes.end(), [&](auto t) { return t == ofmType; }) )
+    {  // Unsupported ofm data type
+        return false;
+    }
+
     return true;
 }
@@ -417,25 +445,34 @@ Flags<QueryResult> EthosU85Constraints::OperatorQuery(OpType opType, const ArchO
     // Validate zeroPoints
     if ( typeInfo )
     {
-        for ( auto zp : query->ifm[0].quantization.zeroPoints )
+        if ( query->ifm[0].quantization )
         {
-            if ( !SupportedZeroPoint(zp, TensorUsage::IFM0, ifmType, opType) )
+            for ( auto zp : query->ifm[0].quantization->zeroPoints )
             {
-                return QueryResult::Unsupported;
+                if ( !SupportedZeroPoint(zp, TensorUsage::IFM0, ifmType, opType) )
+                {
+                    return QueryResult::Unsupported;
+                }
             }
         }
-        for ( auto zp : query->ifm[1].quantization.zeroPoints )
+        if ( query->ifm[1].quantization )
         {
-            if ( !SupportedZeroPoint(zp, TensorUsage::IFM1, ifm2Type, opType) )
+            for ( auto zp : query->ifm[1].quantization->zeroPoints )
             {
-                return QueryResult::Unsupported;
+                if ( !SupportedZeroPoint(zp, TensorUsage::IFM1, ifm2Type, opType) )
+                {
+                    return QueryResult::Unsupported;
+                }
             }
         }
-        for ( auto zp : query->ofm.quantization.zeroPoints )
+        if ( query->ofm.quantization )
         {
-            if ( !SupportedZeroPoint(zp, TensorUsage::OFM, ofmType, opType) )
+            for ( auto zp : query->ofm.quantization->zeroPoints )
             {
-                return QueryResult::Unsupported;
+                if ( !SupportedZeroPoint(zp, TensorUsage::OFM, ofmType, opType) )
+                {
+                    return QueryResult::Unsupported;
+                }
             }
         }
     }
@@ -532,7 +569,7 @@ Flags<QueryResult> EthosU85Constraints::OperatorQuery(OpType opType, const ArchO
     }
 
     if ( opType == OpType::AvgPool && (k->Size().x > 8 || k->Size().y > 8) && !k->Padding().IsZero() &&
-         query->ofm.quantization.scales.size() )
+         query->ofm.quantization->scales.size() )
     {
         if ( req )
         {
diff --git a/ethosu/regor/common/common.hpp b/ethosu/regor/common/common.hpp
index 8aa23450b6603b6175f10705942d0f37924e30ec..74535a4a5d6e8c506814115b474eb3b797e209f0 100644
--- a/ethosu/regor/common/common.hpp
+++ b/ethosu/regor/common/common.hpp
@@ -189,5 +189,23 @@ constexpr bool is_sorted(const TYPE (&list)[SIZE])
     return is_sorted(list, std::less<TYPE>());
 }
 
+// Equivalent functionality not available until C++20
+template<typename T>
+struct readonly_span_t
+{
+private:
+    const T *_start = nullptr;
+    const T *_end = nullptr;
+
+public:
+    readonly_span_t() {}
+    template<size_t SIZE>
+    readonly_span_t(const std::array<T, SIZE> &a) noexcept : _start(&a[0]), _end(&a[0] + SIZE)
+    {
+    }
+    readonly_span_t(const T *p, size_t size) : _start(p), _end(p + size) {}
+    const T *begin() const { return _start; }
+    const T *end() const { return _end; }
+};
+
 }  // namespace regor
diff --git a/ethosu/regor/compiler/operation.cpp b/ethosu/regor/compiler/operation.cpp
index e4b2af9bc49a66441d81f1f7888b8e7c94f2471b..d8bdb4ac6c5eb15743df4ba7aed7f4c72cc5ee03 100644
--- a/ethosu/regor/compiler/operation.cpp
+++ b/ethosu/regor/compiler/operation.cpp
@@ -1,5 +1,5 @@
 //
-// SPDX-FileCopyrightText: Copyright 2021-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
+// SPDX-FileCopyrightText: Copyright 2021-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
 //
 // SPDX-License-Identifier: Apache-2.0
 //
@@ -82,13 +82,17 @@ TensorConnection &Operation::ConnectInput(TensorUsage usage, const std::shared_p
     // because the existing connection (if present) might be the last remaining reference to this operation.
     tensor->AddReader(shared_from_this());
 
-    if ( _inputs.contains(usage) && (_inputs[usage].tensor != tensor) )
+    TensorConnection &input = _inputs[usage];
+    if ( input.tensor != tensor )
     {
-        _inputs[usage].tensor->RemoveReader(shared_from_this());
+        if ( input.tensor )
+        {
+            input.tensor->RemoveReader(shared_from_this());
+        }
+        input.tensor = tensor;
     }
-    _inputs[usage].tensor = tensor;
-    _inputs[usage].shape = tensor->StorageShape();
-    return _inputs[usage];
+    input.shape = tensor->StorageShape();
+    return input;
 }
 
 void Operation::DisconnectInputInvalidatingInputs(TensorUsage usage)
@@ -115,14 +119,17 @@ TensorConnection &Operation::ConnectOutput(TensorUsage usage, const std::shared_
     // because the existing connection (if present) might be the last remaining reference to this operation.
     tensor->AddWriter(shared_from_this());
 
-    if ( _outputs.contains(usage) && (_outputs[usage].tensor != tensor) )
+    TensorConnection &output = _outputs[usage];
+    if ( output.tensor != tensor )
     {
-        _outputs[usage].tensor->RemoveWriter(shared_from_this());
+        if ( output.tensor )
+        {
+            output.tensor->RemoveWriter(shared_from_this());
+        }
+        output.tensor = tensor;
     }
-    _outputs[usage].tensor = tensor;
-    _outputs[usage].shape = tensor->StorageShape();
-
-    return _outputs[usage];
+    output.shape = tensor->StorageShape();
+    return output;
 }
 
 void Operation::Disconnect()
diff --git a/ethosu/regor/compiler/operation_util.hpp b/ethosu/regor/compiler/operation_util.hpp
index b77e8ed47e61f76a06a02711c8ac947329e79749..f58bacee1033ea7bea860f4caeaae8a57ad69e75 100644
--- a/ethosu/regor/compiler/operation_util.hpp
+++ b/ethosu/regor/compiler/operation_util.hpp
@@ -315,7 +315,7 @@ inline ArchFM &Set(ArchFM &fm, const TensorConnection *conn)
     {
         fm.type = conn->tensor->Type();
         fm.shape = conn->shape;
-        fm.quantization = conn->quantization;
+        fm.quantization = &conn->quantization;
     }
     return fm;
 }
diff --git a/ethosu/regor/compiler/scheduler.cpp b/ethosu/regor/compiler/scheduler.cpp
index 62e8bd0aec8ab4295ce691aec96da8be33d5045a..d7017f5a381eff8b5d2e00726c58b36f7ab1b006 100644
--- a/ethosu/regor/compiler/scheduler.cpp
+++ b/ethosu/regor/compiler/scheduler.cpp
@@ -259,10 +259,14 @@ int Scheduler::UpdateSchedulerTensor(TensorUsage usage, SchedulerConnection *con
                 query.transposeMask = producer->OFM()->transpose;
                 if ( _arch->Constraints()->OperatorQuery(producer->Type(), &query, &req).Any(QueryResult::Native) )
                 {
-                    if ( (req.req % ArchRequirement::OutputFormat) && req.ofmFormat == TensorFormat::NHWC )
+                    if ( req.req % ArchRequirement::Tensor )
                     {
-                        tensor->needsLinearFormat = true;
-                        continue;
+                        auto *tr = Get(&req.tensor, TensorUsage::OFM);
+                        if ( tr && tr->format == TensorFormat::NHWC )
+                        {
+                            tensor->needsLinearFormat = true;
+                            continue;
+                        }
                     }
                 }
             }
@@ -308,10 +312,10 @@ int Scheduler::UpdateSchedulerTensor(TensorUsage usage, SchedulerConnection *con
                 query.transposeMask = consumer->OFM()->transpose;
                 if ( _arch->Constraints()->OperatorQuery(consumer->Type(), &query, &req).Any(QueryResult::Native) )
                 {
-                    if ( (req.req % ArchRequirement::InputFormat) )
+                    if ( req.req % ArchRequirement::Tensor )
                     {
-                        if ( (usedAs == TensorUsage::IFM0 && req.ifmFormat == TensorFormat::NHWC) ||
-                             (usedAs == TensorUsage::IFM1 && req.ifm1Format == TensorFormat::NHWC) )
+                        auto *tr = Get(&req.tensor, usedAs);
+                        if ( tr && tr->format == TensorFormat::NHWC )
                         {
                             tensor->needsLinearFormat = true;
                             continue;
diff --git a/ethosu/regor/compiler/scheduler_decompose.hpp b/ethosu/regor/compiler/scheduler_decompose.hpp
index 7dca942de8afafee110a0a2a0911b0d7fa66a479..e7b55965d405a20d8951cf2cf511a1274a0706d2 100644
--- a/ethosu/regor/compiler/scheduler_decompose.hpp
+++ b/ethosu/regor/compiler/scheduler_decompose.hpp
@@ -58,7 +58,7 @@ inline ArchFM &Set(ArchFM &fm, const SchedulerConnection *conn)
         fm.type = conn->Type();
         fm.shape = conn->SliceShape();
         fm.format = conn->tensor->format;
-        fm.quantization = conn->quantization;
+        fm.quantization = &conn->quantization;
     }
     return fm;
 }
diff --git a/ethosu/regor/compiler/scheduler_operation.hpp b/ethosu/regor/compiler/scheduler_operation.hpp
index 43a0a21a41d1b3537475d2af7a3f5dcce68886f7..3b5ce98f2207a26345950f750e871943b15f3791 100644
--- a/ethosu/regor/compiler/scheduler_operation.hpp
+++ b/ethosu/regor/compiler/scheduler_operation.hpp
@@ -74,6 +74,13 @@ public:
         this->uid = GenerateUniqueId();
     }
 
+    SchedulerTensor(DataType type, const Shape &shape, TensorFormat fmt, const std::shared_ptr<Buffer> &buffer) :
+            format(fmt), storageShape(shape), dataType(type)
+    {
+        this->bufferView = BufferView(buffer, 0, DataTypeStorageSizeBits(type), shape, {});
+        this->uid = GenerateUniqueId();
+    }
+
     std::shared_ptr<SchedulerTensor> Clone() const
     {
         auto clone = std::make_shared<SchedulerTensor>(*this);
diff --git a/ethosu/regor/compiler/scheduler_packing.cpp b/ethosu/regor/compiler/scheduler_packing.cpp
index eb2251c15dc47bd2886d38ac3ac96625929f40b0..a1861d3bb2955339dd7e6a3775ec717f64c790ac 100644
--- a/ethosu/regor/compiler/scheduler_packing.cpp
+++ b/ethosu/regor/compiler/scheduler_packing.cpp
@@ -105,12 +105,14 @@ bool IsConnected(const SchedulerOperation &first, const SchedulerOperation &seco
 }  // namespace
 
 SchedulerPacking::SchedulerPacking(Architecture *arch, bool disableChaining) :
-        _arch(arch), _constraints(arch->Constraints()), _disableChaining(disableChaining)
+        _arch(arch), _disableChaining(disableChaining)
 {
 }
 
 std::vector<std::unique_ptr<SchedulerOperation>> SchedulerPacking::Process(const Graph *graph)
 {
+    _graph = graph;
+
     // Get operation list in execution order
     std::vector<Operation *> executionList;
     Graph::TraverseGraphFromEnd(graph->Outputs(), !graph->Persistent().empty(),
@@ -120,7 +122,7 @@ std::vector<std::unique_ptr<SchedulerOperation>> SchedulerPacking::Process(const
             return true;
         });
 
-    FilterOperations(executionList, graph);
+    ConvertOperations(executionList);
 
     PrePackOperations();
 
@@ -128,38 +130,86 @@ std::vector<std::unique_ptr<SchedulerOperation>> SchedulerPacking::Process(const
 
     ReorderOperations();
 
+    _graph = nullptr;
     return std::move(_schedList);
 }
 
-void SchedulerPacking::FilterOperations(const std::vector<Operation *> &executionList, const Graph *graph)
+void SchedulerPacking::ConvertOperation(const Operation *op, std::vector<std::unique_ptr<SchedulerOperation>> &result)
 {
-    // Convert linear Graph Operations to a list of Scheduler Operations
-    for ( Operation *op : executionList )
+    result.clear();
+
+    std::unique_ptr<SchedulerOperation> schedOp = MakeSchedulerOperation(op);
+    if ( !schedOp )
     {
-        auto schedOp = MakeSchedulerOperation(op, graph);
-        if ( !schedOp )
-        {
-            continue;
-        }
+        return;
+    }
 
-        if ( ShouldDecompose(_arch, schedOp.get()) )
+    // Apply architecture related requirements
+    ArchRequirements req;
+    if ( OperatorQuery(_arch, schedOp.get(), &req).Any(QueryResult::HasRequirements) )
+    {
+        // Operator has a list of tensor requirements
+        if ( req.req.Any(ArchRequirement::Tensor) )
         {
-            auto srcKey = schedOp->_srcKey;
-            auto schedOps = DecomposeSchedulerOperation(std::move(schedOp));
-            // Track source keys
-            for ( auto &newOp : schedOps )
+            const ArchTensorRequirement *tr = &req.tensor;
+            do
             {
-                if ( !newOp->_srcKey )
+                SchedulerConnection *conn = schedOp->TryInput(tr->usage);
+                conn = conn ? conn : schedOp->TryOutput(tr->usage);
+                if ( conn )
                 {
-                    newOp->_srcKey = srcKey;
+                    if ( tr->shape )
+                    {
+                        conn->tensor->storageShape = conn->shape = conn->slice.shape = tr->shape;
+                    }
+                    if ( tr->format != TensorFormat::Unknown )
+                    {
+                        conn->tensor->format = tr->format;
+                    }
+                    if ( tr->type != DataType::None )
+                    {
+                        conn->tensor->dataType = tr->type;
+                    }
                 }
-            }
-            _schedList.insert(
-                _schedList.end(), std::make_move_iterator(schedOps.begin()), std::make_move_iterator(schedOps.end()));
+                else
+                {
+                    if ( (tr->usage == TensorUsage::Scratch) && tr->shape )
+                    {
+                        auto scratchTensor = std::make_shared<SchedulerTensor>(tr->type, tr->shape, tr->format);
+                        SchedulerConnection *scratchConn = schedOp->ConnectInput(TensorUsage::Scratch0, scratchTensor);
+                        scratchConn->shape = tr->shape;
+                        scratchTensor->memArea = _arch->FeatureMapMemory();
+                    }
+                }
+                tr = tr->next;
+            } while ( tr );
         }
-        else
+    }
+
+    result.push_back(std::move(schedOp));
+}
+
+void SchedulerPacking::ConvertOperations(const std::vector<Operation *> &executionList)
+{
+    std::vector<std::unique_ptr<SchedulerOperation>> converted(4);
+
+    // Convert linear Graph Operations to a list of Scheduler Operations
+    for ( Operation *op : executionList )
+    {
+        ConvertOperation(op, converted);
+
+        for ( auto &schedOp : converted )
         {
-            _schedList.push_back(std::move(schedOp));
+            if ( ShouldDecompose(_arch, schedOp.get()) )
+            {
+                auto schedOps = DecomposeSchedulerOperation(std::move(schedOp));
+                _schedList.insert(_schedList.end(), std::make_move_iterator(schedOps.begin()),
+                    std::make_move_iterator(schedOps.end()));
+            }
+            else
+            {
+                _schedList.push_back(std::move(schedOp));
+            }
         }
     }
 }
@@ -204,7 +254,7 @@ ArchitectureOpGroupQuery SchedulerPacking::CreateOpGroupQuery(const SchedulerOpe
 // We handle reinterpret by catching it before we create a SchedulerOperation.
 // Mapping is modified so that the OFM GraphIR tensor of the preceding OP and
 // the GraphIR IFM tensor of the succeeding OP map to the same SchedulerTensor.
-void SchedulerPacking::HandleReinterpretCast(Operation *op, const Graph *graph)
+void SchedulerPacking::HandleReinterpretCast(const Operation *op)
 {
     assert(op->Type() == OpType::ReinterpretCast && "Op Type is not ReinterpretCast.");
 
@@ -219,7 +269,7 @@ void SchedulerPacking::HandleReinterpretCast(Operation *op, const Graph *graph)
     {
         schedTensor = std::make_shared<SchedulerTensor>();
         schedTensor->srcTensor = ifmConn->tensor;
-        InitSchedulerTensor(schedTensor.get(), ifmConn->tensor.get(), graph);
+        InitSchedulerTensor(schedTensor.get(), ifmConn->tensor.get());
         _tensorMap.emplace(ifmConn->tensor.get(), schedTensor);
     }
     else
@@ -231,22 +281,21 @@ void SchedulerPacking::HandleReinterpretCast(Operation *op, const Graph *graph)
 
     // If reinterpret cast is the last OP, that means that it's output tensor is the output tensor of the network.
     // We therefore set isGraphOutput to true and make sure the srcTensor maps to the graph output tensor.
-    if ( graph->IsOutput(ofmConn->tensor.get()) )
+    if ( _graph->IsOutput(ofmConn->tensor.get()) )
     {
-        InitSchedulerTensor(schedTensor.get(), ofmConn->tensor.get(), graph);
+        InitSchedulerTensor(schedTensor.get(), ofmConn->tensor.get());
         schedTensor->srcTensor = ofmConn->tensor;
     }
 }
 
-void SchedulerPacking::SchedulerPacking::PrePackOperations()
+void SchedulerPacking::PrePackOperations()
 {
-    // Determine if each operation can run on NPU
     for ( auto &schedOp : _schedList )
     {
         ArchRequirements oReq{};
         Flags<QueryResult> result = OperatorQuery(_arch, schedOp.get(), &oReq);
-        // Assert complete query
         assert(result.Any(QueryResult::Constrained) == false && "Constrained result from complete OperatorQuery");
+        // Determine if each operation can run on NPU
         if ( result.Any(QueryResult::Native) )
         {
             // TODO MLBEDSW-10643: This should be a direct-check against QueryResult::Native
@@ -264,10 +313,34 @@ void SchedulerPacking::SchedulerPacking::PrePackOperations()
         {
             schedOp->SetNpuOp(false);
         }
+
+        // Examine elementwise and set a primary path for cascading.
+        if ( IsBinaryElementwise(schedOp->Type()) )
+        {
+            auto ifm0 = schedOp->Input(TensorUsage::IFM0);
+            auto ifm1 = schedOp->Input(TensorUsage::IFM1);
+            auto ofm = schedOp->Output(TensorUsage::OFM);
+            assert(ifm0 && "Binary elementwise op must have IFM0");
+            assert(ifm1 && "Binary elementwise op must have IFM1");
+            assert(ofm && "Binary elementwise op must have OFM");
+            assert(ifm0->shape.Size() > 0 && "IFM0 must have dimension");
+            assert(ifm1->shape.Size() > 0 && "IFM1 must have dimension");
+            // Choose the non-const IFM path for binary operations that have
+            // a constant input on the first IFM
+            if ( ifm0->tensor->IsConstant() && !ifm1->tensor->IsConstant() )
+            {
+                schedOp->SetPrimaryIfmIndex(1);
+            }
+            // Favour the non-broadcast shape for cascading.
+            else if ( (ifm0->shape != ofm->shape) && (ifm1->shape == ofm->shape) )
+            {
+                schedOp->SetPrimaryIfmIndex(1);
+            }
+        }
     }
 }
 
-void SchedulerPacking::SchedulerPacking::PackOperations()
+void SchedulerPacking::PackOperations()
 {
     LOG_TRACE1("Scheduler Packing (of {0} Ops)\n", _schedList.size());
 
@@ -541,14 +614,13 @@ void SchedulerPacking::InitSchedulerConnection(
     schedConn->reverse = conn.reverse;
     schedConn->resamplingMode = ArchResampling::None;
     schedConn->rounding = conn.rounding;
-    schedConn->SetType(tensor->dataType);
     if ( schedConn->slice.stride )
     {
         schedConn->stepXY = schedConn->slice.stride.WH();
     }
 }
 
-void SchedulerPacking::InitSchedulerTensor(SchedulerTensor *schedTensor, Tensor *tensor, const Graph *graph)
+void SchedulerPacking::InitSchedulerTensor(SchedulerTensor *schedTensor, Tensor *tensor)
 {
     const auto type = tensor->Type();
     // Take scheduler-local copies of graph tensor parameters.
@@ -557,9 +629,9 @@ void SchedulerPacking::InitSchedulerTensor(SchedulerTensor *schedTensor, Tensor
     schedTensor->storageShape = Shape::PadAxes(tensor->StorageShape(), 4, 1);
     schedTensor->dataType = type;
     schedTensor->bufferView = (IsVariablySized(type) || type == DataType::None) ? BufferView() : tensor->View();
-    schedTensor->isGraphInput = graph->IsInput(tensor);
-    schedTensor->isGraphOutput = graph->IsOutput(tensor);
-    schedTensor->isPersistent = graph->IsPersistent(tensor);
+    schedTensor->isGraphInput = _graph->IsInput(tensor);
+    schedTensor->isGraphOutput = _graph->IsOutput(tensor);
+    schedTensor->isPersistent = _graph->IsPersistent(tensor);
     schedTensor->uid = tensor->Uid();
     if ( tensor->View().HasBuffer() )
     {
@@ -578,13 +650,13 @@ void SchedulerPacking::InitSchedulerTensor(SchedulerTensor *schedTensor, Tensor
         }
     }
 
-std::unique_ptr<SchedulerOperation> SchedulerPacking::MakeSchedulerOperation(Operation *op, const Graph *graph)
+std::unique_ptr<SchedulerOperation> SchedulerPacking::MakeSchedulerOperation(const Operation *op)
 {
     assert(op->Type() != OpType::None);
 
     if ( op->Type() == OpType::ReinterpretCast )
     {
-        HandleReinterpretCast(op, graph);
+        HandleReinterpretCast(op);
         return nullptr;
     }
 
@@ -593,7 +665,7 @@ std::unique_ptr<SchedulerOperation> SchedulerPacking::MakeSchedulerOperation(Ope
     schedOp->SetKernel(*op->Kernel());
     schedOp->SetHasScaling(op->HasScaling());
     schedOp->SetAttributes(op->AttributeRef());
-    schedOp->_srcKey = op;
+    schedOp->_srcKey = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(op));
 
     // Get the inputs from the source op and connect with scheduler specific tensor
     for ( const auto *list : {&op->Inputs(), &op->Outputs()} )
@@ -610,7 +682,7 @@ std::unique_ptr<SchedulerOperation> SchedulerPacking::MakeSchedulerOperation(Ope
                 auto tmp = std::make_shared<SchedulerTensor>();
                 pos = _tensorMap.emplace(tensor, tmp).first;
                 tmp->srcTensor = item.second.tensor;
-                InitSchedulerTensor(tmp.get(), tensor, graph);
+                InitSchedulerTensor(tmp.get(), tensor);
             }
 
             // Update consumers and manage connectivity
@@ -651,60 +723,13 @@ std::unique_ptr<SchedulerOperation> SchedulerPacking::MakeSchedulerOperation(Ope
         schedOp->OFM()->transpose = TransposeTypeFromShape(attr->perm);
     }
 
-    // Examine elementwise and set a primary path for cascading.
-    if ( IsBinaryElementwise(op->Type()) )
-    {
-        auto ifm0 = op->Input(TensorUsage::IFM0);
-        auto ifm1 = op->Input(TensorUsage::IFM1);
-        auto ofm = op->Output(TensorUsage::OFM);
-        assert(ifm0 && "Binary elementwise op must have IFM0");
-        assert(ifm1 && "Binary elementwise op must have IFM1");
-        assert(ofm && "Binary elementwise op must have OFM");
-        assert(ifm0->shape.Size() > 0 && "IFM0 must have dimension");
-        assert(ifm1->shape.Size() > 0 && "IFM1 must have dimension");
-        // Choose the non-const IFM path for binary operations that have
-        // a constant input on the first IFM
-        if ( ifm0->tensor->IsConstant() && !ifm1->tensor->IsConstant() )
-        {
-            schedOp->SetPrimaryIfmIndex(1);
-        }
-        // Favour the non-broadcast shape for cascading.
-        else if ( (ifm0->shape != ofm->shape) && (ifm1->shape == ofm->shape) )
-        {
-            schedOp->SetPrimaryIfmIndex(1);
-        }
-    }
-
-    // Check that the Architecture understands what do to with this operator
-    const auto ofmConn = schedOp->OFM();
-    const auto ifm0Conn = schedOp->TryIFM(0);
-    const auto ifm1Conn = schedOp->TryIFM(1);
-    ArchOperatorQuery query;
-    Set(query.ifm[0], ifm0Conn);
-    Set(query.ifm[1], ifm1Conn);
-    Set(query.ofm, ofmConn);
-    query.reverseMask = ofmConn->reverse;
-    query.transposeMask = ofmConn->transpose;
-    query.kernel = schedOp->Kernel();
-
-    ArchRequirements req;
-    if ( _arch->Constraints()->OperatorQuery(op->Type(), &query, &req).Any(QueryResult::Native) )
-    {
-        // Operator requires a scratch tensor
-        if ( req.req.Any(ArchRequirement::ScratchTensor) && req.scratch.size )
-        {
-            auto scratchTensor = std::make_shared<SchedulerTensor>(req.scratch.type, req.scratch.size, req.scratch.format);
-            SchedulerConnection *scratchConn = schedOp->ConnectInput(TensorUsage::Scratch0, scratchTensor);
-            scratchConn->shape = req.scratch.size;
-            scratchTensor->memArea = _arch->FeatureMapMemory();
-        }
-    }
-
     return schedOp;
 }
 
 std::vector<std::unique_ptr<SchedulerOperation>> SchedulerPacking::DecomposeSchedulerOperation(std::unique_ptr<SchedulerOperation> op)
 {
+    auto srcKey = op->_srcKey;
+
     std::vector<std::unique_ptr<SchedulerOperation>> result;
     ArchRequirements req{};
 
@@ -768,10 +793,17 @@ std::vector<std::unique_ptr<SchedulerOperation>> SchedulerPacking::DecomposeSche
         }
         else
         {
+            LOG_PRINT("!!!! Can't decompose op:{}", OpTypeToString(op->Type()));
             assert(false);
         }
         break;
     }
+
+    for ( std::unique_ptr<SchedulerOperation> &so : result )
+        so->_srcKey = srcKey;
+
     return result;
 }
 
+
 }  // namespace regor
diff --git a/ethosu/regor/compiler/scheduler_packing.hpp b/ethosu/regor/compiler/scheduler_packing.hpp
index 0289d04b7dfaa1c60802f10a6b019e9b004dadba..1781c9ce2757f42a7f73f2e04013dd348524b32f 100644
--- a/ethosu/regor/compiler/scheduler_packing.hpp
+++ b/ethosu/regor/compiler/scheduler_packing.hpp
@@ -21,7 +21,6 @@
 #include "common/common.hpp"
 #include "common/logging.hpp"
 
-#include "architecture/architecture_constraints.hpp"
 #include "common/shape.hpp"
 #include "graph.hpp"
 #include "operation.hpp"
@@ -46,11 +45,11 @@ class SchedulerPacking
 {
 protected:
     Architecture *_arch = nullptr;
-    IArchitectureConstraints *_constraints = nullptr;
     bool _disableChaining = false;
     std::vector<std::unique_ptr<SchedulerOperation>> _schedList;
     std::unordered_map<Tensor *, std::shared_ptr<SchedulerTensor>> _tensorMap;
     std::unordered_map<const Buffer *, int> _bufferEquivalenceIdMap;
+    const Graph *_graph = nullptr;
 
 public:
     SchedulerPacking(Architecture *arch, bool disableChaining);
@@ -59,24 +58,20 @@ public:
     std::vector<std::unique_ptr<SchedulerOperation>> Process(const Graph *graph);
 
 private:
-    // Decomposes operations
-    void FilterOperations(const std::vector<Operation *> &executionList, const Graph *graph);
-    // Determines NPU/CPU-target
+    void ConvertOperation(const Operation *op, std::vector<std::unique_ptr<SchedulerOperation>> &result);
+    void ConvertOperations(const std::vector<Operation *> &executionList);
     void PrePackOperations();
-    // Performs fusing/chaining
     void PackOperations();
-    // Reorders CPU-operations
     void ReorderOperations();
     int CanPack(const SchedulerOperation *schedOp, const SchedulerOperation *prevOp, const SchedulerOperation *op, const int prevOpKey) const;
     void InitSchedulerConnection(SchedulerConnection *schedConn, const std::shared_ptr<SchedulerTensor> &tensor, const TensorConnection &conn);
-    void InitSchedulerTensor(SchedulerTensor *schedTensor, Tensor *tensor, const Graph *graph);
-    void HandleReinterpretCast(Operation *op, const Graph *graph);
-    std::unique_ptr<SchedulerOperation> MakeSchedulerOperation(Operation *op, const Graph *graph);
+    void InitSchedulerTensor(SchedulerTensor *schedTensor, Tensor *tensor);
+    std::unique_ptr<SchedulerOperation> MakeSchedulerOperation(const Operation *op);
     std::vector<std::unique_ptr<SchedulerOperation>> DecomposeSchedulerOperation(std::unique_ptr<SchedulerOperation> op);
     ArchResampling ResamplingMode(TensorUsage usage, OpType opType) const;
     ArchitectureOpGroupQuery CreateOpGroupQuery(const SchedulerOperation *schedOp) const;
-    ArchOperatorQuery CreateOperatorQuery(const SchedulerOperation *schedOp) const;
+    void HandleReinterpretCast(const Operation *op);
 };
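
The requirement-chain plumbing in this patch is spread across three files, so a compact picture of the intended call pattern may help review. Below is a minimal, self-contained sketch (not regor code: the enums are reduced stand-ins and the pool is caller-owned rather than the thread_local array) of how a backend builds an ArchTensorRequirement chain with Set() and a NextTensor()-style helper, and how a consumer walks it with Get(), mirroring the MatMul branch of EthosU55Constraints::OperatorQuery and the lookups in Scheduler::UpdateSchedulerTensor.

#include <array>
#include <cassert>
#include <cstddef>
#include <cstdio>

// Reduced stand-ins for the regor types used by the chain.
enum class TensorUsage { None, IFM1, OFM, Scratch };
enum class TensorFormat { Unknown, NHWC };

struct ArchTensorRequirement
{
    const ArchTensorRequirement *next = nullptr;
    TensorUsage usage = TensorUsage::None;
    TensorFormat format = TensorFormat::Unknown;
};

// Mirrors the inline Set() helper: fills one entry and terminates the chain.
void Set(ArchTensorRequirement &req, TensorUsage usage, TensorFormat format)
{
    req.usage = usage;
    req.format = format;
    req.next = nullptr;
}

// Mirrors Get(): returns the entry matching 'usage', or the last entry when
// nothing matches, so callers still check the returned entry's fields.
const ArchTensorRequirement *Get(const ArchTensorRequirement *req, TensorUsage usage)
{
    while ( req->usage != usage && req->next != nullptr )
    {
        req = req->next;
    }
    return req;
}

// Mirrors the anonymous-namespace NextTensor() helper, backed here by a
// caller-owned pool instead of the thread_local array.
ArchTensorRequirement *NextTensor(ArchTensorRequirement *tr, ArchTensorRequirement *pool, unsigned &used, size_t poolSize)
{
    assert(used < poolSize);
    ArchTensorRequirement *info = &pool[used++];
    tr->next = info;
    return info;
}

int main()
{
    std::array<ArchTensorRequirement, 4> pool{};
    unsigned used = 0;

    // Producer side (cf. the MatMul branch of OperatorQuery): a head entry
    // for IFM1 plus a chained entry for OFM, both requiring linear NHWC.
    ArchTensorRequirement head;
    Set(head, TensorUsage::IFM1, TensorFormat::NHWC);
    ArchTensorRequirement *tr = NextTensor(&head, pool.data(), used, pool.size());
    Set(*tr, TensorUsage::OFM, TensorFormat::NHWC);

    // Consumer side (cf. Scheduler::UpdateSchedulerTensor): walk the chain.
    const ArchTensorRequirement *ofmReq = Get(&head, TensorUsage::OFM);
    std::printf("OFM wants NHWC: %d\n", ofmReq->format == TensorFormat::NHWC);
    return 0;
}

Note that Get() deliberately returns the last entry rather than null when no usage matches, which is why the scheduler guards on the returned entry's fields instead of testing for a null pointer.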
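Likewise, a small sketch of the readonly_span_t type added to common.hpp, assuming only the class body shown in that hunk (DataType here is a stand-in enum, and Contains() is a hypothetical helper for illustration). It demonstrates the pattern the constraints tables now rely on: many map entries sharing one immutable std::array through non-owning spans, tested with std::any_of as in SupportedDtypes().

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdio>

enum class DataType { UInt8, Int8, Int16, Int32 };  // stand-in enum

// Copied from the common.hpp hunk above: a non-owning read-only view,
// filling in for std::span (C++20) on pre-C++20 builds.
template<typename T>
struct readonly_span_t
{
private:
    const T *_start = nullptr;
    const T *_end = nullptr;

public:
    readonly_span_t() {}
    template<size_t SIZE>
    readonly_span_t(const std::array<T, SIZE> &a) noexcept : _start(&a[0]), _end(&a[0] + SIZE)
    {
    }
    readonly_span_t(const T *p, size_t size) : _start(p), _end(p + size) {}
    const T *begin() const { return _start; }
    const T *end() const { return _end; }
};

// One shared immutable table; spans reference it without copying, which is
// what lets the per-op type tables share the s_default* arrays.
static const std::array<DataType, 4> s_allTypes = {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32};

static bool Contains(readonly_span_t<DataType> valid, DataType t)
{
    // Same membership test SupportedDtypes() now uses.
    return std::any_of(valid.begin(), valid.end(), [&](auto v) { return v == t; });
}

int main()
{
    readonly_span_t<DataType> types = s_allTypes;  // binds without copying
    std::printf("%d %d\n", Contains(types, DataType::Int16), Contains({}, DataType::Int16));
    return 0;
}

The usual span caveat applies: a readonly_span_t must not outlive the array it references, which is why the shared type tables in both constraints files are file-scope statics.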