From 72c32a332c5622b61de9477cf51b788f08b56542 Mon Sep 17 00:00:00 2001 From: Johan Gunnarsson Date: Tue, 6 May 2025 12:31:43 +0200 Subject: [PATCH] MLBEDSW-10533: Don't fuse transpose to an op with OFM slice * Don't fuse transpose to an op with OFM slice * Also, when fusing a transpose, the primary op should inherit the fused op's OFM slice. Otherwise we might end up with different shapes on OFM slice and OFM and in that case OFM slice shape will be used later on. Signed-off-by: Johan Gunnarsson Change-Id: Idb25cc3a53f0b52dcc59cda1aefa31d9d19a850f --- ethosu/regor/architecture/architecture.hpp | 1 + ethosu/regor/architecture/ethosu85/ethos_u85.cpp | 6 ++++++ ethosu/regor/compiler/scheduler_packing.cpp | 5 +++++ 3 files changed, 12 insertions(+) diff --git a/ethosu/regor/architecture/architecture.hpp b/ethosu/regor/architecture/architecture.hpp index b517ea2f..0b28ac01 100644 --- a/ethosu/regor/architecture/architecture.hpp +++ b/ethosu/regor/architecture/architecture.hpp @@ -168,6 +168,7 @@ struct ArchitectureOpGroupQuery TransposeType transpose; ReverseType reverse; bool isConst; + bool isSliced; }; OpType type; diff --git a/ethosu/regor/architecture/ethosu85/ethos_u85.cpp b/ethosu/regor/architecture/ethosu85/ethos_u85.cpp index 49432315..70e88b1f 100644 --- a/ethosu/regor/architecture/ethosu85/ethos_u85.cpp +++ b/ethosu/regor/architecture/ethosu85/ethos_u85.cpp @@ -1459,6 +1459,12 @@ bool EthosU85OpGroup::Fuse(const ArchitectureOpGroupQuery &op, const std::vector return false; } + // Can't fuse transpose to an op with slice + if ( op.type == OpType::Transpose && _ops[0].ofm.isSliced ) + { + return false; + } + EthosU85Constraints *constraints = static_cast(_arch->_constraints.get()); // Can't fuse a transpose or reverse type that's not supported by primaryOp in opgroup diff --git a/ethosu/regor/compiler/scheduler_packing.cpp b/ethosu/regor/compiler/scheduler_packing.cpp index 2030d5a3..64ce65cb 100644 --- a/ethosu/regor/compiler/scheduler_packing.cpp +++ b/ethosu/regor/compiler/scheduler_packing.cpp @@ -176,6 +176,7 @@ ArchitectureOpGroupQuery SchedulerPacking::CreateOpGroupQuery(const SchedulerOpe query.ifm[0].transpose = ifm0->transpose; query.ifm[0].reverse = ifm0->reverse; query.ifm[0].isConst = ifm0->tensor->IsConstant(); + query.ifm[0].isSliced = !Shape::IsReducedEqual(ifm0->SliceShape(), ifm0->shape); if ( ifm1 ) { query.ifm[1].key = ifm1->tensor->uid; @@ -184,6 +185,7 @@ ArchitectureOpGroupQuery SchedulerPacking::CreateOpGroupQuery(const SchedulerOpe query.ifm[1].transpose = ifm1->transpose; query.ifm[1].reverse = ifm1->reverse; query.ifm[1].isConst = ifm1->tensor->IsConstant(); + query.ifm[1].isSliced = !Shape::IsReducedEqual(ifm1->SliceShape(), ifm1->shape); } query.ofm.key = ofm->tensor->uid; query.ofm.type = ofm->Type(); @@ -191,6 +193,7 @@ ArchitectureOpGroupQuery SchedulerPacking::CreateOpGroupQuery(const SchedulerOpe query.ofm.transpose = ofm->transpose; query.ofm.reverse = ofm->reverse; query.ofm.isConst = false; + query.ofm.isSliced = !Shape::IsReducedEqual(ofm->SliceShape(), ofm->shape); return query; } @@ -297,6 +300,7 @@ void SchedulerPacking::SchedulerPacking::PackOperations() ofmConn->tensor = nextOp->OFM()->tensor; ofmConn->SetType(nextOp->OFM()->Type()); ofmConn->shape = nextOp->OFM()->shape; + ofmConn->slice = nextOp->OFM()->slice; ofmConn->transpose = nextOp->OFM()->transpose; } else if ( nextOp->Type() == OpType::Reverse ) @@ -305,6 +309,7 @@ void SchedulerPacking::SchedulerPacking::PackOperations() ofmConn->tensor = nextOp->OFM()->tensor; ofmConn->SetType(nextOp->OFM()->Type()); ofmConn->shape = nextOp->OFM()->shape; + ofmConn->slice = nextOp->OFM()->slice; ofmConn->reverse = nextOp->OFM()->reverse; } else -- GitLab