From 30ed7f09fbc6033c68f7e1dc069d9b24a419df7d Mon Sep 17 00:00:00 2001
From: Alexander Bengtsson
Date: Tue, 25 Mar 2025 12:58:44 +0100
Subject: [PATCH] MLBEDSW-10574: Align RewriteFullyConnected with shape-changes

- Fixes a functional regression when FullyConnected with dynamic weights is
  rewritten to Matmul.
- MLBEDSW-10418 aligned weight-shapes on TensorConnections with the adjusted
  storage shape of the weight-tensor. This caused regressions for
  FullyConnected with dynamic weights, as RewriteFullyConnectDynamic assumed
  a 1,1,W,C shape on the TensorConnection.

Change-Id: Ifb7dd0f9ee4eced23b80e640143d9ae187c0cbe0
Signed-off-by: Alexander Bengtsson
---
 ethosu/regor/compiler/tflite_graph_optimiser.cpp | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp
index 65a9c1e0..4fcc8294 100644
--- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp
+++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp
@@ -1464,11 +1464,14 @@ Operation *TFLiteGraphOptimiser::RewriteFullyConnectDynamic(Graph *const, Operat
 
     auto ofmShape = Shape::PadAxes(ofm->shape, 4, 1);
     auto ifmShape = Shape::PadAxes(ifm->shape, 4, 1);
-    auto ifm2Shape = Shape::PadAxes(ifm2->shape, 4, 1);
-
-    // Add NHCW Transpose op, to convert to GraphIR/TOSA Matmul representation
-    auto ifm2Reshaped = Shape(ifm2Shape.Batch(), ifm2Shape.Height(), ifm2Shape.Depth(), ifm2Shape.Width());
-    auto transposeOp = CreateTransposeForMatMul(ifm2->tensor, ifm2Reshaped);
+    assert(ifm2->tensor->AxisOrder() == AxisOrder::OHWI);
+    assert(ifm2->shape.Size() == 4 && "FullyConnected with non-4D weights");
+    assert(ifm2->shape.ElementsWH() == 1 && "FullyConnected with non-unit W*H weight-shape");
+
+    // Add a WC-transpose to convert to GraphIR/TOSA Matmul representation
+    // ifm2Transposed is both a reshape from N,1,1,C to 1,1,N,C and then a transpose to 1,1,C,N
+    auto ifm2Transposed = Shape(1, 1, ifm2->shape.Depth(), ifm2->shape.Batch());
+    auto transposeOp = CreateTransposeForMatMul(ifm2->tensor, ifm2Transposed);
     RecordOptimisation(operation, transposeOp);
     auto ifm2Tensor = transposeOp->Output(TensorUsage::OFM)->tensor;
 
@@ -1476,7 +1479,7 @@ Operation *TFLiteGraphOptimiser::RewriteFullyConnectDynamic(Graph *const, Operat
     auto rounding = ifm->tensor->Type() == DataType::Int16 ? RoundMode::NATURAL : RoundMode::DBL;
 
     matMulOp->ConnectInput(TensorUsage::IFM0, ifm->tensor).Set(ifmShape).Set(ifm->quantization).Set(ifm->slice);
-    matMulOp->ConnectInput(TensorUsage::IFM1, ifm2Tensor).Set(ifm2Reshaped).Set(ifm2->quantization).Set(ifm2->slice);
+    matMulOp->ConnectInput(TensorUsage::IFM1, ifm2Tensor).Set(ifm2Transposed).Set(ifm2->quantization).Set(ifm2->slice);
     matMulOp->ConnectOutput(TensorUsage::OFM, ofm->tensor).Set(ofmShape).Set(ofm->quantization).Set(ofm->slice).Set(rounding);
     RecordOptimisation(operation, matMulOp.get());
 
--
GitLab
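
Illustrative sketch (not part of the patch): the following standalone C++ shows the shape logic the diff changes. Shape4, OldIfm2Shape, NewIfm2Shape and the 8x16 weight dimensions are hypothetical stand-ins for regor's Shape class; only the Batch/Height/Width/Depth accessor names and the old/new computations mirror the code in the diff. The point is that swapping W and C of the connection shape (old rewrite) was only correct while the TensorConnection carried 1,1,N,C; once it carries the OHWI storage shape N,1,1,C, building 1,1,C,N from Batch and Depth directly (new rewrite) is what yields the MatMul IFM1 representation.

    // Minimal sketch of the old vs. new IFM1 shape computation for the
    // FullyConnected-with-dynamic-weights -> MatMul rewrite.
    #include <cstdio>

    struct Shape4  // hypothetical stand-in for regor's Shape (NHWC order)
    {
        int n, h, w, c;
        int Batch() const { return n; }
        int Height() const { return h; }
        int Width() const { return w; }
        int Depth() const { return c; }
    };

    // Old rewrite: swap W and C of the connection shape.
    // Correct only when the connection shape is 1,1,N,C.
    Shape4 OldIfm2Shape(const Shape4 &conn)
    {
        return {conn.Batch(), conn.Height(), conn.Depth(), conn.Width()};
    }

    // New rewrite: build 1,1,C,N directly from the OHWI weight shape N,1,1,C.
    Shape4 NewIfm2Shape(const Shape4 &conn)
    {
        return {1, 1, conn.Depth(), conn.Batch()};
    }

    int main()
    {
        // Weight connection shape after MLBEDSW-10418: OHWI storage shape N,1,1,C
        Shape4 ohwiWeights{8, 1, 1, 16};  // example: N=8 output channels, C=16 inputs

        Shape4 oldShape = OldIfm2Shape(ohwiWeights);  // 8,1,16,1 - not a valid MatMul IFM1
        Shape4 newShape = NewIfm2Shape(ohwiWeights);  // 1,1,16,8 - the 1,1,C,N MatMul form

        std::printf("old: %d,%d,%d,%d\n", oldShape.n, oldShape.h, oldShape.w, oldShape.c);
        std::printf("new: %d,%d,%d,%d\n", newShape.n, newShape.h, newShape.w, newShape.c);
        return 0;
    }

Compiling and running the sketch prints "old: 8,1,16,1" and "new: 1,1,16,8", i.e. the old computation no longer produces the 1,1,C,N shape the MatMul expects once the connection carries the N,1,1,C storage shape.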