diff --git a/ethosu/regor/test/test_tflite_supported_operators.cpp b/ethosu/regor/test/test_tflite_supported_operators.cpp index ceece6ec322630f56b614e5500417fbb8a2997c5..70ae089a3edc9b7e1f9c39c8e9e039346972c955 100644 --- a/ethosu/regor/test/test_tflite_supported_operators.cpp +++ b/ethosu/regor/test/test_tflite_supported_operators.cpp @@ -346,16 +346,19 @@ TEST_CASE("Supported operators Common") SECTION("ConstraintRsqrt") { - // Rsqrt is only supported with int8 input + // Rsqrt is only supported with int8 or int16 input auto op = CreateOperation(OpType::Rsqrt, Shape(1, 10, 10, 1), DataType::Int8, Shape(1, 10, 10, 1), DataType::Int8); REQUIRE(supportedOps->Check(op.get()) == true); - op->Disconnect(); - for ( auto dtype : {DataType::UInt8, DataType::Int16, DataType::Int32} ) + op->Input(TensorUsage::IFM)->tensor->ChangeType(DataType::Int16); + op->Output(TensorUsage::OFM)->tensor->ChangeType(DataType::Int16); + REQUIRE(supportedOps->Check(op.get()) == true); + for ( auto dtype : {DataType::UInt8, DataType::Int32} ) { - auto op2 = CreateOperation(OpType::Rsqrt, Shape(1, 10, 10, 1), dtype, Shape(1, 10, 10, 1), dtype); - REQUIRE(supportedOps->Check(op2.get()) == false); - op2->Disconnect(); + op->Input(TensorUsage::IFM)->tensor->ChangeType(dtype); + op->Output(TensorUsage::OFM)->tensor->ChangeType(dtype); + REQUIRE(supportedOps->Check(op.get()) == false); } + op->Disconnect(); } SECTION("ConstraintConstParams") diff --git a/ethosu/regor/tflite/tflite_supported_operators.cpp b/ethosu/regor/tflite/tflite_supported_operators.cpp index 2c1a779d6cbef4b0cde7789279537919b51f50af..4b030ae944410108b5e069400940ec0836086fb0 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators.cpp @@ -483,12 +483,12 @@ bool TfLiteSupportedOperators::ConstraintTCStrides(const Operation *op) Failure(op, fmt::format("stride out of range: ({},{})", stride.x, stride.y), constraint); return false; } - if ( stride == Point2i(1, 2) && !(ifmShape.Height() == 1 && kh == 1) ) + if ( stride == Point2i(1, 2) && !(ifmShape.Width() == 1 && kw == 1) ) { Failure(op, fmt::format("unsupported stride combination: ({},{})", stride.x, stride.y), constraint); return false; } - if ( stride == Point2i(2, 1) && !(ifmShape.Width() == 1 && kw == 1) ) + if ( stride == Point2i(2, 1) && !(ifmShape.Height() == 1 && kh == 1) ) { Failure(op, fmt::format("unsupported stride combination: ({},{})", stride.x, stride.y), constraint); return false; @@ -536,7 +536,7 @@ bool TfLiteSupportedOperators::ConstraintTCShapes(const Operation *op) } else { - Point2i diff = (kernel->Size() - stride); + Point2i diff = Point2i::Max((kernel->Size() - stride), Point2i(0, 0)); if ( (ifmWH * stride + diff) != ofmWH ) { Failure(op, @@ -559,9 +559,9 @@ bool TfLiteSupportedOperators::ConstraintRsqrt(const Operation *op) auto ifmConn = op->Input(TensorUsage::IFM); assert(ifmConn); auto ifmType = ifmConn->tensor->Type(); - if ( ifmType != DataType::Int8 ) + if ( ifmType != DataType::Int8 && ifmType != DataType::Int16 ) { - Failure(op, fmt::format("{} IFM", DataTypeToString(ifmType)), "IFM must be Int8"); + Failure(op, fmt::format("{} IFM", DataTypeToString(ifmType)), "IFM must be Int8 or Int16"); return false; } return true;