diff --git a/ethosu/regor/architecture/architecture.hpp b/ethosu/regor/architecture/architecture.hpp index b517ea2fad6822ea08881d419440bda6fee8e152..0b28ac01a9a27b0b0c5af8e7905404da9a912b99 100644 --- a/ethosu/regor/architecture/architecture.hpp +++ b/ethosu/regor/architecture/architecture.hpp @@ -168,6 +168,7 @@ struct ArchitectureOpGroupQuery TransposeType transpose; ReverseType reverse; bool isConst; + bool isSliced; }; OpType type; diff --git a/ethosu/regor/architecture/ethosu85/ethos_u85.cpp b/ethosu/regor/architecture/ethosu85/ethos_u85.cpp index 494323158f192db1fd856ef6505cf530b9499ea0..70e88b1fd1dedff20b4d8a25f020da5bd97292e2 100644 --- a/ethosu/regor/architecture/ethosu85/ethos_u85.cpp +++ b/ethosu/regor/architecture/ethosu85/ethos_u85.cpp @@ -1459,6 +1459,12 @@ bool EthosU85OpGroup::Fuse(const ArchitectureOpGroupQuery &op, const std::vector return false; } + // Can't fuse transpose to an op with slice + if ( op.type == OpType::Transpose && _ops[0].ofm.isSliced ) + { + return false; + } + EthosU85Constraints *constraints = static_cast(_arch->_constraints.get()); // Can't fuse a transpose or reverse type that's not supported by primaryOp in opgroup diff --git a/ethosu/regor/compiler/scheduler_packing.cpp b/ethosu/regor/compiler/scheduler_packing.cpp index 2030d5a3cb89fc56e62645bed61f332a59776431..64ce65cb2b0cde1d54bff4f01c965f7576f3eb25 100644 --- a/ethosu/regor/compiler/scheduler_packing.cpp +++ b/ethosu/regor/compiler/scheduler_packing.cpp @@ -176,6 +176,7 @@ ArchitectureOpGroupQuery SchedulerPacking::CreateOpGroupQuery(const SchedulerOpe query.ifm[0].transpose = ifm0->transpose; query.ifm[0].reverse = ifm0->reverse; query.ifm[0].isConst = ifm0->tensor->IsConstant(); + query.ifm[0].isSliced = !Shape::IsReducedEqual(ifm0->SliceShape(), ifm0->shape); if ( ifm1 ) { query.ifm[1].key = ifm1->tensor->uid; @@ -184,6 +185,7 @@ ArchitectureOpGroupQuery SchedulerPacking::CreateOpGroupQuery(const SchedulerOpe query.ifm[1].transpose = ifm1->transpose; query.ifm[1].reverse = ifm1->reverse; query.ifm[1].isConst = ifm1->tensor->IsConstant(); + query.ifm[1].isSliced = !Shape::IsReducedEqual(ifm1->SliceShape(), ifm1->shape); } query.ofm.key = ofm->tensor->uid; query.ofm.type = ofm->Type(); @@ -191,6 +193,7 @@ ArchitectureOpGroupQuery SchedulerPacking::CreateOpGroupQuery(const SchedulerOpe query.ofm.transpose = ofm->transpose; query.ofm.reverse = ofm->reverse; query.ofm.isConst = false; + query.ofm.isSliced = !Shape::IsReducedEqual(ofm->SliceShape(), ofm->shape); return query; } @@ -297,6 +300,7 @@ void SchedulerPacking::SchedulerPacking::PackOperations() ofmConn->tensor = nextOp->OFM()->tensor; ofmConn->SetType(nextOp->OFM()->Type()); ofmConn->shape = nextOp->OFM()->shape; + ofmConn->slice = nextOp->OFM()->slice; ofmConn->transpose = nextOp->OFM()->transpose; } else if ( nextOp->Type() == OpType::Reverse ) @@ -305,6 +309,7 @@ void SchedulerPacking::SchedulerPacking::PackOperations() ofmConn->tensor = nextOp->OFM()->tensor; ofmConn->SetType(nextOp->OFM()->Type()); ofmConn->shape = nextOp->OFM()->shape; + ofmConn->slice = nextOp->OFM()->slice; ofmConn->reverse = nextOp->OFM()->reverse; } else