From 744b542eebbbaad9d77bf124718962b1eb3e331a Mon Sep 17 00:00:00 2001 From: Jacob Bohlin Date: Thu, 12 Jun 2025 12:46:34 +0100 Subject: [PATCH] MLBEDSW-10739 Add FullyConnected decomposition Decompose FullyConnected operations as Conv2D. Change-Id: If239303901093a49d909e48affe1da961fa63adc Signed-off-by: Jacob Bohlin --- ethosu/regor/compiler/scheduler_decompose.cpp | 1 + ethosu/regor/compiler/scheduler_packing.cpp | 2 ++ ethosu/regor/tflite/tflite_supported_operators.cpp | 12 ------------ 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/ethosu/regor/compiler/scheduler_decompose.cpp b/ethosu/regor/compiler/scheduler_decompose.cpp index cd203987..19e7b64d 100644 --- a/ethosu/regor/compiler/scheduler_decompose.cpp +++ b/ethosu/regor/compiler/scheduler_decompose.cpp @@ -270,6 +270,7 @@ bool CanDecompose(Architecture *, const SchedulerOperation *schedOp) if ( schedOp->Type() == OpType::AvgPool ) return true; if ( schedOp->Type() == OpType::MaxPool ) return true; if ( schedOp->Type() == OpType::Resize ) return true; + if ( schedOp->Type() == OpType::FullyConnected ) return true; return false; } diff --git a/ethosu/regor/compiler/scheduler_packing.cpp b/ethosu/regor/compiler/scheduler_packing.cpp index 9d0e651c..8d341976 100644 --- a/ethosu/regor/compiler/scheduler_packing.cpp +++ b/ethosu/regor/compiler/scheduler_packing.cpp @@ -738,6 +738,8 @@ std::vector> SchedulerPacking::DecomposeSche switch ( op->Type() ) { + case OpType::FullyConnected: + [[fallthrough]]; case OpType::Conv2D: result = DecomposeConv2D(_arch, std::move(op)); break; diff --git a/ethosu/regor/tflite/tflite_supported_operators.cpp b/ethosu/regor/tflite/tflite_supported_operators.cpp index 1618a396..3e5f505c 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators.cpp @@ -207,18 +207,6 @@ bool TfLiteSupportedOperators::ConstraintFCWeightShape(const Operation *op) return false; } - // IC and OC must be smaller than 2^16 - // TODO MLBEDSW-10739: Decompose FullyConnected - if ( shape[0] > (1 << 16) ) - { - Failure(op, fmt::format("Output channels: {}", shape[0]), "Output channels must be less than 2^16"); - return false; - } - if ( shape[-1] > (1 << 16) ) - { - Failure(op, fmt::format("Input channels: {}", shape[-1]), "Input channels must be less than 2^16"); - return false; - } return true; } -- GitLab