From d964272344c96a151438936bedc9a1eab9a298df Mon Sep 17 00:00:00 2001
From: Alexander Bengtsson
Date: Mon, 17 Mar 2025 17:01:16 +0100
Subject: [PATCH] MLBEDSW-10561: Fix ConstraintBias shape checks

- Supported-ops constrained bias shapes to 1D.
  This caused unnecessary CPU-fallback.
- Update check to constrain all elements to channel axis instead

Change-Id: Ic4e80c6446544d8eb8dcca29cbdb4e1a5e37a4fc
Signed-off-by: Alexander Bengtsson
---
 .../test/test_tflite_supported_operators.cpp  | 34 +++++++++----------
 .../tflite/tflite_supported_operators.cpp     |  7 ++--
 2 files changed, 20 insertions(+), 21 deletions(-)

diff --git a/ethosu/regor/test/test_tflite_supported_operators.cpp b/ethosu/regor/test/test_tflite_supported_operators.cpp
index 286b2695..9592773a 100644
--- a/ethosu/regor/test/test_tflite_supported_operators.cpp
+++ b/ethosu/regor/test/test_tflite_supported_operators.cpp
@@ -222,57 +222,57 @@ TEST_CASE("Supported operators Common")
 
     SECTION("ConstraintBias")
     {
-        auto op = CreateOperation(OpType::DepthwiseConv2D, Shape(1, 5, 5, 1), DataType::Int8, Shape(1, 5, 5, 1), DataType::Int8);
-        std::vector<int8_t> wValues(1, 1);
-        auto weights = CreateTensor("weights", Shape(1, 1, 1, 1), DataType::Int8, std::move(wValues));
+        auto op = CreateOperation(OpType::DepthwiseConv2D, Shape(1, 5, 5, 2), DataType::Int8, Shape(1, 5, 5, 2), DataType::Int8);
+        std::vector<int8_t> wValues(2, 1);
+        auto weights = CreateTensor("weights", Shape(1, 1, 1, 2), DataType::Int8, std::move(wValues));
         weights->SetAxisOrder(AxisOrder::IHWO);
         op->ConnectInput(TensorUsage::Weights, weights).Set(Quantization::Unit());
         REQUIRE(supportedOps->Check(op.get()) == true);
         {
             // Bias must be const
-            auto bias = CreateTensor("bias", Shape(1), DataType::Int32);
+            auto bias = CreateTensor("bias", Shape(1, 1, 1, 2), DataType::Int32);
             op->ConnectInput(TensorUsage::Scales, bias).Set(Quantization::Unit());
             REQUIRE(supportedOps->Check(op.get()) == false);
         }
         {
-            // Bias must be 1D
-            std::vector<int32_t> values(1, 1);
-            auto bias = CreateTensor("bias", Shape(1, 1, 1, 1), DataType::Int32, std::move(values));
+            // Bias values must be stored in channel-axis
+            std::vector<int32_t> values(2, 1);
+            auto bias = CreateTensor("bias", Shape(1, 1, 2, 1), DataType::Int32, std::move(values));
             op->ConnectInput(TensorUsage::Scales, bias).Set(Quantization::Unit());
             REQUIRE(supportedOps->Check(op.get()) == false);
         }
         {
             // Bias can't be 8bit
-            std::vector<int8_t> values(1, 1);
-            auto bias = CreateTensor("bias", Shape(1), DataType::Int8, std::move(values));
+            std::vector<int8_t> values(2, 1);
+            auto bias = CreateTensor("bias", Shape(1, 1, 1, 2), DataType::Int8, std::move(values));
             op->ConnectInput(TensorUsage::Scales, bias).Set(Quantization::Unit());
             REQUIRE(supportedOps->Check(op.get()) == false);
         }
         {
             // Bias can't be 16bit
-            std::vector<int16_t> values(1, 1);
-            auto bias = CreateTensor("bias", Shape(1), DataType::Int16, std::move(values));
+            std::vector<int16_t> values(2, 1);
+            auto bias = CreateTensor("bias", Shape(1, 1, 1, 2), DataType::Int16, std::move(values));
             op->ConnectInput(TensorUsage::Scales, bias).Set(Quantization::Unit());
             REQUIRE(supportedOps->Check(op.get()) == false);
         }
         {
             // Bias can be 32 bit
-            std::vector<int32_t> values(1, std::numeric_limits<int32_t>::max());
-            auto bias = CreateTensor("bias", Shape(1), DataType::Int32, std::move(values));
+            std::vector<int32_t> values(2, std::numeric_limits<int32_t>::max());
+            auto bias = CreateTensor("bias", Shape(1, 1, 1, 2), DataType::Int32, std::move(values));
             op->ConnectInput(TensorUsage::Scales, bias).Set(Quantization::Unit());
             REQUIRE(supportedOps->Check(op.get()) == true);
         }
         {
             // Bias can be 40 bit
-            std::vector<int64_t> values(1, (1LL << 40) - 1);
-            auto bias = CreateTensor("bias", Shape(1), DataType::Int64, std::move(values));
+            std::vector<int64_t> values(2, (1LL << 40) - 1);
+            auto bias = CreateTensor("bias", Shape(1, 1, 1, 2), DataType::Int64, std::move(values));
             op->ConnectInput(TensorUsage::Scales, bias).Set(Quantization::Unit());
             REQUIRE(supportedOps->Check(op.get()) == true);
         }
         {
             // Bias can't be >40 bit
-            std::vector<int64_t> values(1, std::numeric_limits<int64_t>::max());
-            auto bias = CreateTensor("bias", Shape(1), DataType::Int64, std::move(values));
+            std::vector<int64_t> values(2, std::numeric_limits<int64_t>::max());
+            auto bias = CreateTensor("bias", Shape(1, 1, 1, 2), DataType::Int64, std::move(values));
             op->ConnectInput(TensorUsage::Scales, bias).Set(Quantization::Unit());
             REQUIRE(supportedOps->Check(op.get()) == false);
         }
diff --git a/ethosu/regor/tflite/tflite_supported_operators.cpp b/ethosu/regor/tflite/tflite_supported_operators.cpp
index d7d1f386..54047fe3 100644
--- a/ethosu/regor/tflite/tflite_supported_operators.cpp
+++ b/ethosu/regor/tflite/tflite_supported_operators.cpp
@@ -356,11 +356,10 @@ bool TfLiteSupportedOperators::ConstraintBias(const Operation *op)
     {
         return true;
     }
-    int biasDim = bConn->shape.Size();
-
-    if ( biasDim > 1 )
+    auto bShape = bConn->shape;
+    if ( bShape.Elements() > bShape.Depth() )
     {
-        Failure(op, fmt::format("Operation has {}D bias shape.", biasDim), "The bias tensor shape must be 1D");
+        Failure(op, fmt::format("Bias shape: {}", bShape.ToString()), "bias-values must be stored in channel axis");
         return false;
     }
     if ( !bConn->tensor->IsConstant() )
--
GitLab