From 744051f27571730b541e8c4a4e3588404522fab2 Mon Sep 17 00:00:00 2001
From: Philip Hall
Date: Tue, 11 Feb 2025 14:54:24 +0000
Subject: [PATCH] MLBEDSW-10351: Fix transpose input shapes and strides

- Pad and trim input shapes for Ethos-U55 transpose implementation to 4 axes.
- Fix depth-slicing by using correct depths for NWHC ifm and ofm strides.

Signed-off-by: Philip Hall
Change-Id: Ic9b6aeb259ce9249c402bdb1e4c1308929ce7995
---
 .../ethosu55/ethos_u55_register_cs_generator.cpp | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.cpp b/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.cpp
index 332bb96c..aca04de6 100644
--- a/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.cpp
+++ b/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.cpp
@@ -1296,9 +1296,9 @@ void EthosU55RCSGenerator::InsertTransposeCommand(const HLCStripe *stripe, Tempo
     assert(op->subOps.empty());
     assert(ifm.format == TensorFormat::NHWC);
     assert(ofm.format == TensorFormat::NHWC);
-    assert(ifm.shape.Size() <= 4);
     assert(((ofm.transpose == TransposeType::NWHC) || !ifm.slice.shape || (ifm.shape == ifm.slice.shape)) && "Implementation cannot be sliced");
-    ifm.shape = Shape::PadAxes(ifm.shape, 4, 0);
+    ifm.shape = Shape::PadAxes(ifm.shape, 4, 1);
+    assert((ifm.shape.AxisProduct(0, ifm.shape.Size() - 3) <= 1) && "Batch transposes unsupported");
     Shape outShape = ifm.shape.Permute(unsigned(ofm.transpose));
 
     // Which indexed axes have been swapped
@@ -1421,6 +1421,11 @@ void EthosU55RCSGenerator::InsertTransposeCommand(const HLCStripe *stripe, Tempo
         outFM.strides = Shape(1, elementSize, elementSize * ifm.shape.Width() * ifm.shape.Height(), elementSize);
         inFM.strides = Shape(1, elementSize * ifm.shape.Width() * ifm.shape.Depth(), elementSize, elementSize);
     }
+    else if ( ofm.transpose == TransposeType::NWHC )
+    {
+        outFM.strides = Shape(1, elementSize * ofm.shape.Depth(), elementSize * ofm.shape.Depth() * outFM.shape.Height(), elementSize);
+        inFM.strides = Shape(1, elementSize * ifm.shape.Depth() * inFM.shape.Width(), elementSize * ifm.shape.Depth(), elementSize);
+    }
     else
     {
         outFM.strides = Shape(1, elementSize * depth, elementSize * depth * outFM.shape.Height(), elementSize);
-- 
GitLab