diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp index bde354e8a8cdce04f25018513798a72d9f64d00d..78de0637b869f60098c10e6be89d45a37047a94a 100644 --- a/ethosu/regor/compiler/graphir_optimiser.cpp +++ b/ethosu/regor/compiler/graphir_optimiser.cpp @@ -1188,9 +1188,9 @@ Operation *GraphIrOptimiser::RewriteSelect(Graph *const graph, Operation *const const OpType opType = operation->Type(); if ( opType == OpType::Select || opType == OpType::SelectV2 ) { - auto selectorConn = operation->Input(TensorUsage::Params); - const auto ifm0Conn = operation->Input(TensorUsage::IFM0); // Used if selector is true - const auto ifm1Conn = operation->Input(TensorUsage::IFM1); // Used if selector is false + auto selectorConn = operation->Input(TensorUsage::IFM0); + const auto ifm1Conn = operation->Input(TensorUsage::IFM1); // Used if selector is true + const auto ifm2Conn = operation->Input(TensorUsage::IFM2); // Used if selector is false const auto ofmConn = operation->Output(TensorUsage::OFM); // Cast selector IFM (bool8) to same data type as the OFM (if needed) @@ -1204,10 +1204,10 @@ Operation *GraphIrOptimiser::RewriteSelect(Graph *const graph, Operation *const } // Break down SELECT(selector, a, b) into OR(AND(a, selector), AND_NOT(b, selector)) - auto andOp = CreateBinaryElementwise(OpType::And, ifm0Conn->tensor, selectorConn->tensor, - ifm0Conn->quantization, selectorConn->quantization, ofmConn->quantization, ofmConn->tensor->Type()); - auto andNotOp = CreateBinaryElementwise(OpType::AndNot, ifm1Conn->tensor, selectorConn->tensor, + auto andOp = CreateBinaryElementwise(OpType::And, ifm1Conn->tensor, selectorConn->tensor, ifm1Conn->quantization, selectorConn->quantization, ofmConn->quantization, ofmConn->tensor->Type()); + auto andNotOp = CreateBinaryElementwise(OpType::AndNot, ifm2Conn->tensor, selectorConn->tensor, + ifm2Conn->quantization, selectorConn->quantization, ofmConn->quantization, ofmConn->tensor->Type()); auto orOp = CreateBinaryElementwise(OpType::Or, andOp->Output(TensorUsage::OFM)->tensor, andNotOp->Output(TensorUsage::OFM)->tensor, ofmConn->quantization, ofmConn->quantization, ofmConn->quantization, ofmConn->tensor->Type()); diff --git a/ethosu/regor/tflite/tflite_mapping.cpp b/ethosu/regor/tflite/tflite_mapping.cpp index 01ba021b45736ab16b5d5b5c48ef9069d73f831a..d29420b70462eb1aded2df359fdca4f8d24af2ea 100644 --- a/ethosu/regor/tflite/tflite_mapping.cpp +++ b/ethosu/regor/tflite/tflite_mapping.cpp @@ -535,12 +535,12 @@ const std::multimap TfLiteMapping::_inputTensorIndices = { {OpType::ScatterNd, TensorUsage::Params}, {OpType::SegmentSum, TensorUsage::IFM0}, {OpType::SegmentSum, TensorUsage::Params}, - {OpType::Select, TensorUsage::Params}, {OpType::Select, TensorUsage::IFM0}, {OpType::Select, TensorUsage::IFM1}, - {OpType::SelectV2, TensorUsage::Params}, + {OpType::Select, TensorUsage::IFM2}, {OpType::SelectV2, TensorUsage::IFM0}, {OpType::SelectV2, TensorUsage::IFM1}, + {OpType::SelectV2, TensorUsage::IFM2}, {OpType::Shape, TensorUsage::IFM0}, {OpType::SignBit, TensorUsage::IFM0}, {OpType::Sin, TensorUsage::IFM0}, diff --git a/ethosu/regor/tosa/tosa_reader.cpp b/ethosu/regor/tosa/tosa_reader.cpp index 0035837ee10adeb2ebf394da186d181dbad0aaa3..b1aab0baf9c543a6bcb3eff83dfe3aff511d4a2e 100644 --- a/ethosu/regor/tosa/tosa_reader.cpp +++ b/ethosu/regor/tosa/tosa_reader.cpp @@ -270,7 +270,7 @@ TOSA_REGISTER_OP(LOGICAL_NOT, NONE, GraphAp TOSA_REGISTER_OP(NEGATE, NegateAttribute, GraphApi::GraphTensorUsage::IFM); TOSA_REGISTER_OP(RECIPROCAL, NONE, GraphApi::GraphTensorUsage::IFM); TOSA_REGISTER_OP(RSQRT, NONE, GraphApi::GraphTensorUsage::IFM); -TOSA_REGISTER_OP(SELECT, NONE, GraphApi::GraphTensorUsage::Params, GraphApi::GraphTensorUsage::IFM, GraphApi::GraphTensorUsage::IFM); +TOSA_REGISTER_OP(SELECT, NONE, GraphApi::GraphTensorUsage::IFM, GraphApi::GraphTensorUsage::IFM, GraphApi::GraphTensorUsage::IFM); TOSA_REGISTER_OP(EQUAL, NONE, GraphApi::GraphTensorUsage::IFM, GraphApi::GraphTensorUsage::IFM); TOSA_REGISTER_OP(GREATER, NONE, GraphApi::GraphTensorUsage::IFM, GraphApi::GraphTensorUsage::IFM); TOSA_REGISTER_OP(GREATER_EQUAL, NONE, GraphApi::GraphTensorUsage::IFM, GraphApi::GraphTensorUsage::IFM);