From acc10645801823cf93189f21246678b5823fa40a Mon Sep 17 00:00:00 2001 From: Alexander Bengtsson Date: Fri, 11 Apr 2025 10:58:17 +0200 Subject: [PATCH] MLBEDSW-10692: Align FC weight constraints with reshape-removal. - Supported-ops for pattern-matching enabled reshape-removal before supported-operator checks. - Align supported-operator checks for FC-weights and asserts in RewriteFullyConnectDynamic with the new behaviour. Change-Id: I9df993d311f9fe588923649d1b8c6dd096926170 Signed-off-by: Alexander Bengtsson --- ethosu/regor/compiler/tflite_graph_optimiser.cpp | 1 - ethosu/regor/test/test_tflite_supported_operators.cpp | 4 +++- ethosu/regor/tflite/tflite_supported_operators.cpp | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp index de3d6364..b0758ae2 100644 --- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp +++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp @@ -1456,7 +1456,6 @@ Operation *TFLiteGraphOptimiser::RewriteFullyConnectDynamic(Graph *const, Operat auto ofmShape = Shape::PadAxes(ofm->shape, 4, 1); auto ifmShape = Shape::PadAxes(ifm->shape, 4, 1); - assert(ifm2->tensor->AxisOrder() == AxisOrder::OHWI); assert(ifm2->shape.Size() == 4 && "FullyConnected with non-4D weights"); assert(ifm2->shape.ElementsWH() == 1 && "FullyConnected with non-unit W*H weight-shape"); diff --git a/ethosu/regor/test/test_tflite_supported_operators.cpp b/ethosu/regor/test/test_tflite_supported_operators.cpp index 7aeff463..056abaca 100644 --- a/ethosu/regor/test/test_tflite_supported_operators.cpp +++ b/ethosu/regor/test/test_tflite_supported_operators.cpp @@ -124,7 +124,9 @@ TEST_CASE("Supported operators Common") weights->SetAxisOrder(AxisOrder::OHWI); op->ConnectInput(TensorUsage::Weights, weights).Set(Quantization::Unit()); REQUIRE(supportedOps->Check(op.get()) == true); - op->Input(TensorUsage::Weights)->tensor->Reshape(Shape(2, 2, 1, 2)); + // reshape and reconnect tensor + weights->Reshape(Shape(2, 2, 1, 2)); + op->ConnectInput(TensorUsage::Weights, weights).Set(Quantization::Unit()); REQUIRE(supportedOps->Check(op.get()) == false); op->Disconnect(); } diff --git a/ethosu/regor/tflite/tflite_supported_operators.cpp b/ethosu/regor/tflite/tflite_supported_operators.cpp index 4890bf7a..079d9fda 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators.cpp @@ -199,7 +199,7 @@ bool TfLiteSupportedOperators::ConstraintFCWeightShape(const Operation *op) auto weights = op->Input(TensorUsage::Weights); assert(weights); assert(weights->tensor); - const auto &shape = weights->tensor->StorageShape(); + const auto &shape = weights->shape; // Total elements must be equal to first-dim * last-dim if ( shape.Size() < 2 || (shape.Elements() != (shape[0] * shape[-1])) ) { -- GitLab