diff --git a/ethosu/regor/test/test_tflite_supported_operators.cpp b/ethosu/regor/test/test_tflite_supported_operators.cpp index 286b2695419c4cce223a9f63a68651b7ea02f76b..9592773a7463125ed7869a3b0bc047bc676e6d1f 100644 --- a/ethosu/regor/test/test_tflite_supported_operators.cpp +++ b/ethosu/regor/test/test_tflite_supported_operators.cpp @@ -222,57 +222,57 @@ TEST_CASE("Supported operators Common") SECTION("ConstraintBias") { - auto op = CreateOperation(OpType::DepthwiseConv2D, Shape(1, 5, 5, 1), DataType::Int8, Shape(1, 5, 5, 1), DataType::Int8); - std::vector wValues(1, 1); - auto weights = CreateTensor("weights", Shape(1, 1, 1, 1), DataType::Int8, std::move(wValues)); + auto op = CreateOperation(OpType::DepthwiseConv2D, Shape(1, 5, 5, 2), DataType::Int8, Shape(1, 5, 5, 2), DataType::Int8); + std::vector wValues(2, 1); + auto weights = CreateTensor("weights", Shape(1, 1, 1, 2), DataType::Int8, std::move(wValues)); weights->SetAxisOrder(AxisOrder::IHWO); op->ConnectInput(TensorUsage::Weights, weights).Set(Quantization::Unit()); REQUIRE(supportedOps->Check(op.get()) == true); { // Bias must be const - auto bias = CreateTensor("bias", Shape(1), DataType::Int32); + auto bias = CreateTensor("bias", Shape(1, 1, 1, 2), DataType::Int32); op->ConnectInput(TensorUsage::Scales, bias).Set(Quantization::Unit()); REQUIRE(supportedOps->Check(op.get()) == false); } { - // Bias must be 1D - std::vector values(1, 1); - auto bias = CreateTensor("bias", Shape(1, 1, 1, 1), DataType::Int32, std::move(values)); + // Bias values must be stored in channel-axis + std::vector values(2, 1); + auto bias = CreateTensor("bias", Shape(1, 1, 2, 1), DataType::Int32, std::move(values)); op->ConnectInput(TensorUsage::Scales, bias).Set(Quantization::Unit()); REQUIRE(supportedOps->Check(op.get()) == false); } { // Bias can't be 8bit - std::vector values(1, 1); - auto bias = CreateTensor("bias", Shape(1), DataType::Int8, std::move(values)); + std::vector values(2, 1); + auto bias = CreateTensor("bias", Shape(1, 1, 1, 2), DataType::Int8, std::move(values)); op->ConnectInput(TensorUsage::Scales, bias).Set(Quantization::Unit()); REQUIRE(supportedOps->Check(op.get()) == false); } { // Bias can't be 16bit - std::vector values(1, 1); - auto bias = CreateTensor("bias", Shape(1), DataType::Int16, std::move(values)); + std::vector values(2, 1); + auto bias = CreateTensor("bias", Shape(1, 1, 1, 2), DataType::Int16, std::move(values)); op->ConnectInput(TensorUsage::Scales, bias).Set(Quantization::Unit()); REQUIRE(supportedOps->Check(op.get()) == false); } { // Bias can be 32 bit - std::vector values(1, std::numeric_limits::max()); - auto bias = CreateTensor("bias", Shape(1), DataType::Int32, std::move(values)); + std::vector values(2, std::numeric_limits::max()); + auto bias = CreateTensor("bias", Shape(1, 1, 1, 2), DataType::Int32, std::move(values)); op->ConnectInput(TensorUsage::Scales, bias).Set(Quantization::Unit()); REQUIRE(supportedOps->Check(op.get()) == true); } { // Bias can be 40 bit - std::vector values(1, (1LL << 40) - 1); - auto bias = CreateTensor("bias", Shape(1), DataType::Int64, std::move(values)); + std::vector values(2, (1LL << 40) - 1); + auto bias = CreateTensor("bias", Shape(1, 1, 1, 2), DataType::Int64, std::move(values)); op->ConnectInput(TensorUsage::Scales, bias).Set(Quantization::Unit()); REQUIRE(supportedOps->Check(op.get()) == true); } { // Bias can't be >40 bit - std::vector values(1, std::numeric_limits::max()); - auto bias = CreateTensor("bias", Shape(1), DataType::Int64, std::move(values)); + std::vector values(2, std::numeric_limits::max()); + auto bias = CreateTensor("bias", Shape(1, 1, 1, 2), DataType::Int64, std::move(values)); op->ConnectInput(TensorUsage::Scales, bias).Set(Quantization::Unit()); REQUIRE(supportedOps->Check(op.get()) == false); } diff --git a/ethosu/regor/tflite/tflite_supported_operators.cpp b/ethosu/regor/tflite/tflite_supported_operators.cpp index d7d1f3866f433ebff644a77668529a2b9cb792bf..54047fe37cebd9f17e2edd76d20aee472a38f9f8 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators.cpp @@ -356,11 +356,10 @@ bool TfLiteSupportedOperators::ConstraintBias(const Operation *op) { return true; } - int biasDim = bConn->shape.Size(); - - if ( biasDim > 1 ) + auto bShape = bConn->shape; + if ( bShape.Elements() > bShape.Depth() ) { - Failure(op, fmt::format("Operation has {}D bias shape.", biasDim), "The bias tensor shape must be 1D"); + Failure(op, fmt::format("Bias shape: {}", bShape.ToString()), "bias-values must be stored in channel axis"); return false; } if ( !bConn->tensor->IsConstant() )