diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.cpp b/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.cpp index 332bb96c259d07401be6b7ff8218f5c83873586b..aca04de6900b63d15152accf14a9c7edc3cbb309 100644 --- a/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.cpp +++ b/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.cpp @@ -1296,9 +1296,9 @@ void EthosU55RCSGenerator::InsertTransposeCommand(const HLCStripe *stripe, Tempo assert(op->subOps.empty()); assert(ifm.format == TensorFormat::NHWC); assert(ofm.format == TensorFormat::NHWC); - assert(ifm.shape.Size() <= 4); assert(((ofm.transpose == TransposeType::NWHC) || !ifm.slice.shape || (ifm.shape == ifm.slice.shape)) && "Implementation cannot be sliced"); - ifm.shape = Shape::PadAxes(ifm.shape, 4, 0); + ifm.shape = Shape::PadAxes(ifm.shape, 4, 1); + assert((ifm.shape.AxisProduct(0, ifm.shape.Size() - 3) <= 1) && "Batch transposes unsupported"); Shape outShape = ifm.shape.Permute(unsigned(ofm.transpose)); // Which indexed axes have been swapped @@ -1421,6 +1421,11 @@ void EthosU55RCSGenerator::InsertTransposeCommand(const HLCStripe *stripe, Tempo outFM.strides = Shape(1, elementSize, elementSize * ifm.shape.Width() * ifm.shape.Height(), elementSize); inFM.strides = Shape(1, elementSize * ifm.shape.Width() * ifm.shape.Depth(), elementSize, elementSize); } + else if ( ofm.transpose == TransposeType::NWHC ) + { + outFM.strides = Shape(1, elementSize * ofm.shape.Depth(), elementSize * ofm.shape.Depth() * outFM.shape.Height(), elementSize); + inFM.strides = Shape(1, elementSize * ifm.shape.Depth() * inFM.shape.Width(), elementSize * ifm.shape.Depth(), elementSize); + } else { outFM.strides = Shape(1, elementSize * depth, elementSize * depth * outFM.shape.Height(), elementSize);