From 859c48fe9b9ba400d8fbf5bc55c90ca0e8540651 Mon Sep 17 00:00:00 2001 From: Alexander Bengtsson Date: Fri, 14 Mar 2025 17:05:27 +0100 Subject: [PATCH] MLBEDSW-10414: Port Vela supported-operator checks (5) Add the following supported-operator checks: ConstraintRsqrt: Constrains input-precision to Int8 ConstraintConstParams: Constrains constant parameter-tensors Currently only OpType::Slice is added Change-Id: Ib69a7e8bf0b41e9ed35f235185749ca84eb05c96 Signed-off-by: Alexander Bengtsson --- ethosu/regor/compiler/tensor_properties.hpp | 5 +++ .../test/test_tflite_supported_operators.cpp | 33 ++++++++++++++ .../tflite/tflite_supported_operators.cpp | 43 +++++++++++++++++++ .../tflite/tflite_supported_operators.hpp | 2 + 4 files changed, 83 insertions(+) diff --git a/ethosu/regor/compiler/tensor_properties.hpp b/ethosu/regor/compiler/tensor_properties.hpp index 6ab0cf10..0ae2508f 100644 --- a/ethosu/regor/compiler/tensor_properties.hpp +++ b/ethosu/regor/compiler/tensor_properties.hpp @@ -75,6 +75,11 @@ constexpr inline bool IsIFM(TensorUsage usage) return (usage & TensorUsage::TypeMask) == TensorUsage::IFM; } +constexpr inline bool IsParams(TensorUsage usage) +{ + return (usage & TensorUsage::TypeMask) == TensorUsage::Params; +} + template constexpr inline TensorUsage MakeTensorUsage(TensorUsage type, NUMERIC index) { diff --git a/ethosu/regor/test/test_tflite_supported_operators.cpp b/ethosu/regor/test/test_tflite_supported_operators.cpp index 9592773a..a8c6db60 100644 --- a/ethosu/regor/test/test_tflite_supported_operators.cpp +++ b/ethosu/regor/test/test_tflite_supported_operators.cpp @@ -343,6 +343,39 @@ TEST_CASE("Supported operators Common") REQUIRE(supportedOps->Check(op.get()) == true); op->Disconnect(); } + + SECTION("ConstraintRsqrt") + { + // Rsqrt is only supported with int8 input + auto op = CreateOperation(OpType::Rsqrt, Shape(1, 10, 10, 1), DataType::Int8, Shape(1, 10, 10, 1), DataType::Int8); + REQUIRE(supportedOps->Check(op.get()) == true); + op->Disconnect(); + for ( auto dtype : {DataType::UInt8, DataType::Int16, DataType::Int32} ) + { + auto op2 = CreateOperation(OpType::Rsqrt, Shape(1, 10, 10, 1), dtype, Shape(1, 10, 10, 1), dtype); + REQUIRE(supportedOps->Check(op2.get()) == false); + op2->Disconnect(); + } + } + + SECTION("ConstraintConstParams") + { + auto op = CreateOperation(OpType::Slice, Shape(1, 10, 10, 1), DataType::Int8, Shape(1, 10, 10, 1), DataType::Int8); + auto begin = CreateTensor("begin", Shape(4), DataType::Int32); + auto slice = CreateTensor("slice", Shape(4), DataType::Int32); + // validate parameter-tensors can't be dynamic + op->ConnectInput(TensorUsage::Params0, begin); + op->ConnectInput(TensorUsage::Params1, slice); + REQUIRE(supportedOps->Check(op.get()) == false); + + // validate parameter-tensors can be const + begin = CreateTensor("begin", Shape(4), DataType::Int32, std::vector(4, 1)); + slice = CreateTensor("slice", Shape(4), DataType::Int32, std::vector(4, 1)); + op->ConnectInput(TensorUsage::Params0, begin); + op->ConnectInput(TensorUsage::Params1, slice); + REQUIRE(supportedOps->Check(op.get()) == true); + op->Disconnect(); + } } TEST_CASE("Supported operators EthosU55") diff --git a/ethosu/regor/tflite/tflite_supported_operators.cpp b/ethosu/regor/tflite/tflite_supported_operators.cpp index 1485da80..a27463f5 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators.cpp @@ -551,6 +551,47 @@ bool TfLiteSupportedOperators::ConstraintTCShapes(const Operation *op) return true; } +bool TfLiteSupportedOperators::ConstraintRsqrt(const Operation *op) +{ + OpType opType = op->Type(); + if ( opType != OpType::Rsqrt ) + { + return true; + } + auto ifmConn = op->Input(TensorUsage::IFM); + assert(ifmConn); + auto ifmType = ifmConn->tensor->Type(); + if ( ifmType != DataType::Int8 ) + { + Failure(op, fmt::format("{} IFM", DataTypeToString(ifmType)), "IFM must be Int8"); + return false; + } + return true; +} + +bool TfLiteSupportedOperators::ConstraintConstParams(const Operation *op) +{ + OpType opType = op->Type(); + if ( opType != OpType::Slice ) + { + return true; + } + + for ( const auto item : op->Inputs().pairs() ) + { + auto usage = item.first; + auto &conn = item.second; + if ( IsParams(usage) && !conn.tensor->IsConstant() ) + { + assert(conn.tensor); + Failure(op, fmt::format("non-constant tensor {}", conn.tensor->Name()), "Parameter tensors must be constant"); + return false; + } + } + + return true; +} + void TfLiteSupportedOperators::Failure(const Operation *op, const std::string &message, const std::string &constraint) { assert(op); @@ -597,6 +638,8 @@ TfLiteSupportedOperators::TfLiteSupportedOperators(IArchitectureConstraints *con &TfLiteSupportedOperators::ConstraintMaxPool, &TfLiteSupportedOperators::ConstraintTCStrides, &TfLiteSupportedOperators::ConstraintTCShapes, + &TfLiteSupportedOperators::ConstraintRsqrt, + &TfLiteSupportedOperators::ConstraintConstParams, }; } diff --git a/ethosu/regor/tflite/tflite_supported_operators.hpp b/ethosu/regor/tflite/tflite_supported_operators.hpp index cd31817b..8473eb1c 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.hpp +++ b/ethosu/regor/tflite/tflite_supported_operators.hpp @@ -67,5 +67,7 @@ private: bool ConstraintMaxPool(const Operation *op); bool ConstraintTCStrides(const Operation *op); bool ConstraintTCShapes(const Operation *op); + bool ConstraintRsqrt(const Operation *op); + bool ConstraintConstParams(const Operation *op); }; } // namespace regor -- GitLab