diff --git a/ethosu/regor/tflite/tflite_supported_operators_u55.cpp b/ethosu/regor/tflite/tflite_supported_operators_u55.cpp
index 62f861c810244096fcc5523889b5a932102c0929..4b75720110ad09e79db920c795d10466199420d0 100644
--- a/ethosu/regor/tflite/tflite_supported_operators_u55.cpp
+++ b/ethosu/regor/tflite/tflite_supported_operators_u55.cpp
@@ -96,6 +96,7 @@ TfLiteSupportedOperatorsU55::TfLiteSupportedOperatorsU55(IArchitectureConstraint
         &TfLiteSupportedOperatorsU55::Constraint32bitOps,
         &TfLiteSupportedOperatorsU55::ConstraintKernelStride,
         &TfLiteSupportedOperatorsU55::ConstraintUnrolledKernelStride,
+        &TfLiteSupportedOperatorsU55::ConstraintMatmul,
     };
 }
 
@@ -256,4 +257,73 @@ bool TfLiteSupportedOperatorsU55::ConstraintUnrolledKernelStride(const Operation
     return true;
 }
 
+// Checks whether a matrix multiplication is within the limits enforced here.
+// Only BatchMatMul and FullyConnected with non-constant weights are checked; every
+// other operator, and FullyConnected with constant weights, passes unconditionally.
+// Reports a Failure and returns false when the reduced axis (IFM depth, after the
+// adj_x permutation) or the OFM depth exceeds 1 << 16, or when the IFM is not Int8.
+bool TfLiteSupportedOperatorsU55::ConstraintMatmul(const Operation *op)
+{
+    const OpType opType = op->Type();
+    if ( opType != OpType::BatchMatMul && opType != OpType::FullyConnected )
+    {
+        return true;
+    }
+    auto ifmConn = op->Input(TensorUsage::IFM0);
+    auto ofmConn = op->Output(TensorUsage::OFM);
+    assert(ifmConn);
+    assert(ofmConn);
+    auto ifmShape = ifmConn->shape;
+    auto ofmShape = ofmConn->shape;
+
+    bool adj_x = false;
+    if ( opType == OpType::FullyConnected )
+    {
+        auto wConn = op->Input(TensorUsage::Weights);
+        assert(wConn);
+        if ( wConn->tensor->IsConstant() )
+        {
+            // Non-dynamic weights, not a matmul
+            return true;
+        }
+    }
+    else
+    {
+        // BatchMatMul: adj_x comes from the TFLite operator options. It stays false when
+        // the passthrough operator or its BatchMatMulOptions are not available.
+        const tflite::Operator *const passthrough = static_cast<const tflite::Operator *>(op->Passthrough());
+        const auto options = passthrough ? passthrough->builtin_options_as_BatchMatMulOptions() : nullptr;
+        if ( options )
+        {
+            adj_x = options->adj_x();
+        }
+    }
+
+    if ( adj_x )
+    {
+        // NHWC-transpose ifm-shape
+        ifmShape = ifmShape.Permute(0x3201);
+    }
+    // OFM-depth and the reduced axis (ifmShape.Depth()) is constrained to 16-bits
+    static constexpr int maxAxis = 1 << 16;
+    if ( ifmShape.Depth() > maxAxis )
+    {
+        static const std::string constraint = fmt::format("The reduced axis must be less than or equal to {}", maxAxis);
+        Failure(op, fmt::format("The reduced Axis is: {}", ifmShape.Depth()), constraint);
+        return false;
+    }
+    if ( ofmShape.Depth() > maxAxis )
+    {
+        static const std::string constraint = fmt::format("The OFM depth must be less than or equal to {}", maxAxis);
+        Failure(op, fmt::format("OFM channel: {}", ofmShape.Depth()), constraint);
+        return false;
+    }
+    if ( ifmConn->tensor->Type() != DataType::Int8 )
+    {
+        Failure(op, fmt::format("IFM has datatype: {}", DataTypeToString(ifmConn->tensor->Type())), "IFM must be Int8");
+        return false;
+    }
+    return true;
+}
+
 } // namespace regor
diff --git a/ethosu/regor/tflite/tflite_supported_operators_u55.hpp b/ethosu/regor/tflite/tflite_supported_operators_u55.hpp
index 1d14f40370e2a56f76852662ac37db3abaf12285..7d626ea129d02d97df8d9edc517da946cc72ff11 100644
--- a/ethosu/regor/tflite/tflite_supported_operators_u55.hpp
+++ b/ethosu/regor/tflite/tflite_supported_operators_u55.hpp
@@ -45,5 +45,6 @@ private:
     bool Constraint32bitOps(const Operation *op);
     bool ConstraintKernelStride(const Operation *op);
     bool ConstraintUnrolledKernelStride(const Operation *op);
+    bool ConstraintMatmul(const Operation *op);
 };
 } // namespace regor