diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp b/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp index 6b3a14569c7b3cb3f60fc9c34ff9389e5b5e408f..8543c36b676755137a8d27a04db33bd3526c9945 100644 --- a/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp +++ b/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp @@ -188,6 +188,7 @@ bool EthosU55Constraints::SupportsRescale(DataType fromType, DataType toType) } +static const std::array s_validAddOfmTypes = {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32, DataType::Int64}; static const std::array s_validAddMulTypes = {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}; static const std::array s_validMaxAbsTypes = {DataType::UInt8, DataType::Int8, DataType::Int16, DataType::Int32}; static const std::array s_validClzShlTypes = {DataType::Int32}; @@ -242,6 +243,11 @@ bool EthosU55Constraints::SupportedDtypes(OpType opType, DataType ifmType, DataT switch ( opType ) { case OpType::Add: + { + ifmTypes = s_validAddMulTypes; + ofmTypes = s_validAddOfmTypes; + } + break; case OpType::Sub: case OpType::Mul: {