From f62b1871f576a24e3d246e88e1af828bfdfbd089 Mon Sep 17 00:00:00 2001 From: Alexander Bengtsson Date: Wed, 26 Mar 2025 17:18:20 +0100 Subject: [PATCH] Fix CPU-fallback regressions for TransposeConv and Rsqrt MLBEDSW-10615: ConstraintTCStrides had inverted checks for uneven strides. MLBEDSW-10611: ConstraintTCShapes should truncate negative size-stride difference to 0. MLBEDSW-10613: Regor supports Rsqrt with Int16 input. Change-Id: I1360e4c15d8dc437d90c593c9477e4ec992d6f47 Signed-off-by: Alexander Bengtsson --- .../test/test_tflite_supported_operators.cpp | 15 +++++++++------ .../regor/tflite/tflite_supported_operators.cpp | 10 +++++----- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/ethosu/regor/test/test_tflite_supported_operators.cpp b/ethosu/regor/test/test_tflite_supported_operators.cpp index ceece6ec..70ae089a 100644 --- a/ethosu/regor/test/test_tflite_supported_operators.cpp +++ b/ethosu/regor/test/test_tflite_supported_operators.cpp @@ -346,16 +346,19 @@ TEST_CASE("Supported operators Common") SECTION("ConstraintRsqrt") { - // Rsqrt is only supported with int8 input + // Rsqrt is only supported with int8 or int16 input auto op = CreateOperation(OpType::Rsqrt, Shape(1, 10, 10, 1), DataType::Int8, Shape(1, 10, 10, 1), DataType::Int8); REQUIRE(supportedOps->Check(op.get()) == true); - op->Disconnect(); - for ( auto dtype : {DataType::UInt8, DataType::Int16, DataType::Int32} ) + op->Input(TensorUsage::IFM)->tensor->ChangeType(DataType::Int16); + op->Output(TensorUsage::OFM)->tensor->ChangeType(DataType::Int16); + REQUIRE(supportedOps->Check(op.get()) == true); + for ( auto dtype : {DataType::UInt8, DataType::Int32} ) { - auto op2 = CreateOperation(OpType::Rsqrt, Shape(1, 10, 10, 1), dtype, Shape(1, 10, 10, 1), dtype); - REQUIRE(supportedOps->Check(op2.get()) == false); - op2->Disconnect(); + op->Input(TensorUsage::IFM)->tensor->ChangeType(dtype); + op->Output(TensorUsage::OFM)->tensor->ChangeType(dtype); + REQUIRE(supportedOps->Check(op.get()) == false); } + op->Disconnect(); } SECTION("ConstraintConstParams") diff --git a/ethosu/regor/tflite/tflite_supported_operators.cpp b/ethosu/regor/tflite/tflite_supported_operators.cpp index 2c1a779d..4b030ae9 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators.cpp @@ -483,12 +483,12 @@ bool TfLiteSupportedOperators::ConstraintTCStrides(const Operation *op) Failure(op, fmt::format("stride out of range: ({},{})", stride.x, stride.y), constraint); return false; } - if ( stride == Point2i(1, 2) && !(ifmShape.Height() == 1 && kh == 1) ) + if ( stride == Point2i(1, 2) && !(ifmShape.Width() == 1 && kw == 1) ) { Failure(op, fmt::format("unsupported stride combination: ({},{})", stride.x, stride.y), constraint); return false; } - if ( stride == Point2i(2, 1) && !(ifmShape.Width() == 1 && kw == 1) ) + if ( stride == Point2i(2, 1) && !(ifmShape.Height() == 1 && kh == 1) ) { Failure(op, fmt::format("unsupported stride combination: ({},{})", stride.x, stride.y), constraint); return false; @@ -536,7 +536,7 @@ bool TfLiteSupportedOperators::ConstraintTCShapes(const Operation *op) } else { - Point2i diff = (kernel->Size() - stride); + Point2i diff = Point2i::Max((kernel->Size() - stride), Point2i(0, 0)); if ( (ifmWH * stride + diff) != ofmWH ) { Failure(op, @@ -559,9 +559,9 @@ bool TfLiteSupportedOperators::ConstraintRsqrt(const Operation *op) auto ifmConn = op->Input(TensorUsage::IFM); assert(ifmConn); auto ifmType = ifmConn->tensor->Type(); - if ( ifmType != DataType::Int8 ) + if ( ifmType != DataType::Int8 && ifmType != DataType::Int16 ) { - Failure(op, fmt::format("{} IFM", DataTypeToString(ifmType)), "IFM must be Int8"); + Failure(op, fmt::format("{} IFM", DataTypeToString(ifmType)), "IFM must be Int8 or Int16"); return false; } return true; -- GitLab