diff --git a/ethosu/regor/common/shape.hpp b/ethosu/regor/common/shape.hpp
index 1a2fa0004d51ffbb24a0cf06c8288fe0865c1988..71a248f8e165fe5e64c13b709281e7190319ab91 100644
--- a/ethosu/regor/common/shape.hpp
+++ b/ethosu/regor/common/shape.hpp
@@ -556,6 +556,12 @@ public:
         return true;
     }
 
+    // Returns true if two shapes are equal, ignoring leading dimensions that are 1
+    static bool IsReducedEqual(const Shape &a, const Shape &b)
+    {
+        return MaxAxisFunc<std::not_equal_to<int>, 1>(a, b) == 0;
+    }
+
     template<typename TYPE>
     int ToNHWC(TYPE *buffer, size_t length) const
     {
@@ -663,6 +669,32 @@ private:
         return tmp;
     }
 
+    template<typename FUNC, int MISSING_VALUE>
+    static unsigned MaxAxisFunc(const Shape &a, const Shape &b)
+    {
+        bool a_longer = a.Size() >= b.Size();
+        int length = a_longer ? a.Size() : b.Size();
+        assert(length < 32);
+        int shortest = a_longer ? b.Size() : a.Size();
+        assert(shortest < 32);
+
+        auto *pa = a.Storage();
+        auto *pb = b.Storage();
+        unsigned axisMask = 0;
+
+        int i = 0;
+        for ( ; i < shortest; i++ )
+        {
+            if ( FUNC()(pa[i], pb[i]) ) axisMask |= 1 << i;
+        }
+        for ( ; i < length; i++ )
+        {
+            if ( a_longer && FUNC()(pa[i], MISSING_VALUE) ) axisMask |= 1 << i;
+            else if ( !a_longer && FUNC()(MISSING_VALUE, pb[i]) ) axisMask |= 1 << i;
+        }
+        return axisMask;
+    }
+
     // Apply a function to the maximum number of axes between two shapes. For missing
     // axes either take from the longest shape, or substitute a constant value.
     template
diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp
index eb103f68a9195a7aac133e0ea12f441a08889742..30add616e9b30d03618df6d5531fce80edef7e9d 100644
--- a/ethosu/regor/compiler/graphir_optimiser.cpp
+++ b/ethosu/regor/compiler/graphir_optimiser.cpp
@@ -2035,7 +2035,6 @@ Operation *GraphIrOptimiser::MoveSplitSliceToConsumer(Graph *const, Operation *c
     if ( ofm->Readers().size() == 1 )
     {
         auto cons = ofm->Readers().front();
-        auto consOfmConn = cons->Output(TensorUsage::OFM);
         auto *consIfm0 = cons->IFM(0);
         auto *consIfm1 = cons->IFM(1);
 
@@ -2044,15 +2043,16 @@ Operation *GraphIrOptimiser::MoveSplitSliceToConsumer(Graph *const, Operation *c
         {
             // Check if ifm0 consumer has correct shape
             auto *consIfm0Conn = cons->Input(TensorUsage::IFM0);
-            ifmShapeEqual = consIfm0Conn->shape == ofmConn->shape;
+            ifmShapeEqual = Shape::IsReducedEqual(consIfm0Conn->shape, ofmConn->shape);
         }
         else if ( consIfm1 != nullptr && consIfm1 == ofm )
         {
             // Check if ifm1 consumer has correct shape
             auto *consIfm1Conn = cons->Input(TensorUsage::IFM1);
-            ifmShapeEqual = consIfm1Conn->shape == ofmConn->shape;
+            ifmShapeEqual = Shape::IsReducedEqual(consIfm1Conn->shape, ofmConn->shape);
         }
+
+        // Calculate the consumer transpose type
         TransposeType consumerTranspose = TransposeType::None;
         if ( cons->Type() == OpType::Transpose )
         {
@@ -2061,7 +2061,7 @@ Operation *GraphIrOptimiser::MoveSplitSliceToConsumer(Graph *const, Operation *c
         // We can only move to consumer if there is no transpose on the op that we move to,
         // otherwise the IFM shape may change and transposition will be wrong.
-        if ( !IsReshape(cons->Type()) && ofmConn->shape == Shape::PadAxes(ofm->StorageShape(), 4, 1) && IsNone(consumerTranspose) && ifmShapeEqual )
+        if ( !IsReshape(cons->Type()) && Shape::IsReducedEqual(ofmConn->shape, ofm->StorageShape()) && IsNone(consumerTranspose) && ifmShapeEqual )
         {
             // Split/Slice can be performed by tensor consumer
             MoveToConsumer(operation, cons.get());
diff --git a/ethosu/regor/test/test_shape.cpp b/ethosu/regor/test/test_shape.cpp
index 055413e8ddd946d2399d60ddc8c03cb35f9b7539..735b950d2cf052c75ee9cfac683e207afe4630a5 100644
--- a/ethosu/regor/test/test_shape.cpp
+++ b/ethosu/regor/test/test_shape.cpp
@@ -1,5 +1,5 @@
 //
-// SPDX-FileCopyrightText: Copyright 2021, 2023-2024 Arm Limited and/or its affiliates
+// SPDX-FileCopyrightText: Copyright 2021, 2023-2025 Arm Limited and/or its affiliates
 //
 // SPDX-License-Identifier: Apache-2.0
 //
@@ -458,3 +458,21 @@ TEST_CASE("To Mask")
     uint32_t mask2 = shape2.ToMask();
     REQUIRE((mask2 == 0x2301));
 }
+
+TEST_CASE("Is reduced equal")
+{
+    Shape shape1a(3, 3);
+    Shape shape1b(1, 1, 3, 3);
+    REQUIRE(Shape::IsReducedEqual(shape1a, shape1b));
+    REQUIRE(Shape::IsReducedEqual(shape1b, shape1a));
+
+    Shape shape3a(3, 3);
+    Shape shape3b(3, 3, 1);
+    REQUIRE_FALSE(Shape::IsReducedEqual(shape3a, shape3b));
+    REQUIRE_FALSE(Shape::IsReducedEqual(shape3b, shape3a));
+
+    Shape shape2a(1);
+    Shape shape2b(1, 1, 1);
+    REQUIRE(Shape::IsReducedEqual(shape2a, shape2b));
+    REQUIRE(Shape::IsReducedEqual(shape2b, shape2a));
+}