From 6499306f50d1af2c6e6678d0b12c0db346071be4 Mon Sep 17 00:00:00 2001 From: Johan Gunnarsson Date: Thu, 12 Dec 2024 18:11:11 +0100 Subject: [PATCH] MLBEDSW-10176: Fix TOSA SELECT regression Revert back to the original TOSA tensor mapping and adjust the TFLite tensor mapping after it instead. Signed-off-by: Johan Gunnarsson Change-Id: Ib4105c9926741a1ac7aa3a22a69113ec98779fe7 --- ethosu/regor/compiler/graphir_optimiser.cpp | 12 ++++++------ ethosu/regor/tflite/tflite_mapping.cpp | 4 ++-- ethosu/regor/tosa/tosa_reader.cpp | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp index bde354e8..78de0637 100644 --- a/ethosu/regor/compiler/graphir_optimiser.cpp +++ b/ethosu/regor/compiler/graphir_optimiser.cpp @@ -1188,9 +1188,9 @@ Operation *GraphIrOptimiser::RewriteSelect(Graph *const graph, Operation *const const OpType opType = operation->Type(); if ( opType == OpType::Select || opType == OpType::SelectV2 ) { - auto selectorConn = operation->Input(TensorUsage::Params); - const auto ifm0Conn = operation->Input(TensorUsage::IFM0); // Used if selector is true - const auto ifm1Conn = operation->Input(TensorUsage::IFM1); // Used if selector is false + auto selectorConn = operation->Input(TensorUsage::IFM0); + const auto ifm1Conn = operation->Input(TensorUsage::IFM1); // Used if selector is true + const auto ifm2Conn = operation->Input(TensorUsage::IFM2); // Used if selector is false const auto ofmConn = operation->Output(TensorUsage::OFM); // Cast selector IFM (bool8) to same data type as the OFM (if needed) @@ -1204,10 +1204,10 @@ Operation *GraphIrOptimiser::RewriteSelect(Graph *const graph, Operation *const } // Break down SELECT(selector, a, b) into OR(AND(a, selector), AND_NOT(b, selector)) - auto andOp = CreateBinaryElementwise(OpType::And, ifm0Conn->tensor, selectorConn->tensor, - ifm0Conn->quantization, selectorConn->quantization, ofmConn->quantization, ofmConn->tensor->Type()); - auto andNotOp = CreateBinaryElementwise(OpType::AndNot, ifm1Conn->tensor, selectorConn->tensor, + auto andOp = CreateBinaryElementwise(OpType::And, ifm1Conn->tensor, selectorConn->tensor, ifm1Conn->quantization, selectorConn->quantization, ofmConn->quantization, ofmConn->tensor->Type()); + auto andNotOp = CreateBinaryElementwise(OpType::AndNot, ifm2Conn->tensor, selectorConn->tensor, + ifm2Conn->quantization, selectorConn->quantization, ofmConn->quantization, ofmConn->tensor->Type()); auto orOp = CreateBinaryElementwise(OpType::Or, andOp->Output(TensorUsage::OFM)->tensor, andNotOp->Output(TensorUsage::OFM)->tensor, ofmConn->quantization, ofmConn->quantization, ofmConn->quantization, ofmConn->tensor->Type()); diff --git a/ethosu/regor/tflite/tflite_mapping.cpp b/ethosu/regor/tflite/tflite_mapping.cpp index 01ba021b..d29420b7 100644 --- a/ethosu/regor/tflite/tflite_mapping.cpp +++ b/ethosu/regor/tflite/tflite_mapping.cpp @@ -535,12 +535,12 @@ const std::multimap TfLiteMapping::_inputTensorIndices = { {OpType::ScatterNd, TensorUsage::Params}, {OpType::SegmentSum, TensorUsage::IFM0}, {OpType::SegmentSum, TensorUsage::Params}, - {OpType::Select, TensorUsage::Params}, {OpType::Select, TensorUsage::IFM0}, {OpType::Select, TensorUsage::IFM1}, - {OpType::SelectV2, TensorUsage::Params}, + {OpType::Select, TensorUsage::IFM2}, {OpType::SelectV2, TensorUsage::IFM0}, {OpType::SelectV2, TensorUsage::IFM1}, + {OpType::SelectV2, TensorUsage::IFM2}, {OpType::Shape, TensorUsage::IFM0}, {OpType::SignBit, TensorUsage::IFM0}, {OpType::Sin, TensorUsage::IFM0}, diff --git a/ethosu/regor/tosa/tosa_reader.cpp b/ethosu/regor/tosa/tosa_reader.cpp index 0035837e..b1aab0ba 100644 --- a/ethosu/regor/tosa/tosa_reader.cpp +++ b/ethosu/regor/tosa/tosa_reader.cpp @@ -270,7 +270,7 @@ TOSA_REGISTER_OP(LOGICAL_NOT, NONE, GraphAp TOSA_REGISTER_OP(NEGATE, NegateAttribute, GraphApi::GraphTensorUsage::IFM); TOSA_REGISTER_OP(RECIPROCAL, NONE, GraphApi::GraphTensorUsage::IFM); TOSA_REGISTER_OP(RSQRT, NONE, GraphApi::GraphTensorUsage::IFM); -TOSA_REGISTER_OP(SELECT, NONE, GraphApi::GraphTensorUsage::Params, GraphApi::GraphTensorUsage::IFM, GraphApi::GraphTensorUsage::IFM); +TOSA_REGISTER_OP(SELECT, NONE, GraphApi::GraphTensorUsage::IFM, GraphApi::GraphTensorUsage::IFM, GraphApi::GraphTensorUsage::IFM); TOSA_REGISTER_OP(EQUAL, NONE, GraphApi::GraphTensorUsage::IFM, GraphApi::GraphTensorUsage::IFM); TOSA_REGISTER_OP(GREATER, NONE, GraphApi::GraphTensorUsage::IFM, GraphApi::GraphTensorUsage::IFM); TOSA_REGISTER_OP(GREATER_EQUAL, NONE, GraphApi::GraphTensorUsage::IFM, GraphApi::GraphTensorUsage::IFM); -- GitLab