From 9209e7a4f8e6dadeb9a04553331ba7938a6d8969 Mon Sep 17 00:00:00 2001 From: Jacob Bohlin Date: Wed, 30 Jul 2025 14:11:05 +0100 Subject: [PATCH] MLBEDSW-11035 Extend Conv3D and TransposeConv Op queries Decompose properties were not set correctly for Conv3D and TransposeConv. Also did minor refactoring of the OperatorQuery function. Change-Id: I6a665819898776b166df70f0a9a8729d98175e92 Signed-off-by: Jacob Bohlin --- .../ethosu85/ethos_u85_constraints.cpp | 101 ++++++++++++------ 1 file changed, 66 insertions(+), 35 deletions(-) diff --git a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp index dfe21622..c774ebfd 100644 --- a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp +++ b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp @@ -332,10 +332,52 @@ bool EthosU85Constraints::SupportedZeroPoint(int64_t zp, TensorUsage usage, Data return true; } +// Validate that tensor dimensions are supported +static bool SupportedTensorDims(OpType opType, Shape ifmShape, Shape ifm2Shape, Shape ofmShape) +{ + for ( const auto &s : {ifmShape, ifm2Shape, ofmShape} ) + { + if ( !s ) continue; + auto shape = Shape::PadAxes(s, 4, 1); + // validate that leading dimensions are unit + int leadingDims = opType == OpType::Conv3D ? 1 : ofmShape.Size() - 3; + for ( int i = 0; i < leadingDims; i++ ) + { + if ( shape[i] > 1 ) + { + return false; + } + } + } + return true; +} + +// Validate that tensor axes are supported +static bool SupportedTensorAxes(Shape ifmShape, Shape ifm2Shape, Shape ofmShape) +{ + static constexpr int32_t MAX_AXIS = (1 << 16); + for ( const auto &s : {ifmShape, ifm2Shape, ofmShape} ) + { + if ( !s ) continue; + auto shape = Shape::PadAxes(s, 4, 1); + // validate that HWC are within valid range + for ( int i = shape.Size() - 3; i < shape.Size(); i++ ) + { + if ( shape[i] > MAX_AXIS ) + { + return false; + } + } + } + return true; +} + Flags EthosU85Constraints::OperatorQuery(OpType opType, const ArchOperatorQuery *query, ArchRequirements *req) { Flags result = QueryResult::Native; - static constexpr int32_t MAX_AXIS = (1 << 16); + const auto &ifmShape = query->ifm[0].shape; + const auto &ifm2Shape = query->ifm[1].shape; + const auto &ofmShape = query->ofm.shape; // Check hardware-required substitutions first if ( (opType == OpType::Sigmoid) || (opType == OpType::Tanh) ) @@ -353,6 +395,11 @@ Flags EthosU85Constraints::OperatorQuery(OpType opType, const ArchO { if ( req ) { + // Validate tensor dimensions + if ( !SupportedTensorDims(opType, ifmShape, ifm2Shape, ofmShape) ) + { + req->decomposeProps.Set(ArchProperty::TensorDims); + } req->req.Set(ArchRequirement::Decompose); } return query ? QueryResult::NativeHasReq : QueryResult::NativeConstrainedHasReq; @@ -447,9 +494,6 @@ Flags EthosU85Constraints::OperatorQuery(OpType opType, const ArchO } } - const auto &ifmShape = query->ifm[0].shape; - const auto &ifm2Shape = query->ifm[1].shape; - const auto &ofmShape = query->ofm.shape; const auto ifmType = query->ifm[0].type; const auto ifm2Type = query->ifm[1].type; const auto ofmType = query->ofm.type; @@ -503,40 +547,27 @@ Flags EthosU85Constraints::OperatorQuery(OpType opType, const ArchO return QueryResult::Unsupported; } - // Validate tensor-shapes - if ( shapeInfo ) + + // Validate tensor dimensions + if ( !SupportedTensorDims(opType, ifmShape, ifm2Shape, ofmShape) ) { - for ( const auto &s : {ifmShape, ifm2Shape, ofmShape} ) + if ( req ) { - if ( !s ) continue; - auto shape = Shape::PadAxes(s, 4, 1); - // validate that leading dimensions are unit - for ( int i = 0; i < shape.Size() - 3; i++ ) - { - if ( shape[i] > 1 ) - { - if ( req ) - { - req->req.Set(ArchRequirement::Decompose); - req->decomposeProps.Set(ArchProperty::TensorDims); - } - result.Set(QueryResult::HasRequirements); - } - } - // validate that HWC are within valid range - for ( int i = shape.Size() - 3; i < shape.Size(); i++ ) - { - if ( shape[i] > MAX_AXIS ) - { - if ( req ) - { - req->req.Set(ArchRequirement::Decompose); - req->decomposeProps.Set(ArchProperty::TensorAxis); - } - result.Set(QueryResult::HasRequirements); - } - } + req->req.Set(ArchRequirement::Decompose); + req->decomposeProps.Set(ArchProperty::TensorDims); } + result.Set(QueryResult::HasRequirements); + } + + // Validate tensor axes + if ( !SupportedTensorAxes(ifmShape, ifm2Shape, ofmShape) ) + { + if ( req ) + { + req->req.Set(ArchRequirement::Decompose); + req->decomposeProps.Set(ArchProperty::TensorAxis); + } + result.Set(QueryResult::HasRequirements); } // Detailed operator queries -- GitLab