From 22f69adb92bc55987fccdbe6d8fe4b753c3602d7 Mon Sep 17 00:00:00 2001 From: Johan Gunnarsson Date: Fri, 24 Jan 2025 13:30:20 +0100 Subject: [PATCH] MLBEDSW-9868: Adjust BufferReader asserts to always allow index 0 Index 0 is unstrided so we can always allow reading that index regardless of stride. Also added a few more asserts. Signed-off-by: Johan Gunnarsson Change-Id: Icf7ac8c46665d4e651232d0c0ba66733dce99dc6 --- ethosu/regor/common/buffer_view.hpp | 9 +++-- .../regor/compiler/tflite_graph_optimiser.cpp | 38 +++++++++---------- 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/ethosu/regor/common/buffer_view.hpp b/ethosu/regor/common/buffer_view.hpp index 81c2575d..3dc282cd 100644 --- a/ethosu/regor/common/buffer_view.hpp +++ b/ethosu/regor/common/buffer_view.hpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2021, 2023-2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2021, 2023-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -410,7 +410,7 @@ public: TYPE operator[](size_t index) const { assert(index < _count); - assert(((_strideBytes.Size() <= 2) || (_count == 1)) && "View does not guarantee linear access"); + assert(((index == 0) || (_strideBytes.Size() <= 2)) && "View does not guarantee linear access"); return _get(_data, index * _strideBytes.Depth()); } @@ -438,6 +438,7 @@ public: using iterator_category = std::forward_iterator_tag; public: + iterator_base_t() = default; iterator_base_t(const iterator_base_t &other) = default; iterator_base_t(GetFunc fn, const void *p, size_t index, size_t strideBytes) : _get(fn), _data(p), _offset(index * strideBytes), _strideBytes(strideBytes) @@ -520,7 +521,7 @@ public: { assert(index < _count); assert(_strideBytes[-1] == sizeof(TYPE)); - assert(((_strideBytes.Size() <= 2) || (_count == 1)) && "View does not guarantee linear access"); + assert(((index == 0) || (_strideBytes.Size() <= 2)) && "View does not guarantee linear access"); return _data[index]; } @@ -528,7 +529,7 @@ public: { assert(index < _count); assert(_strideBytes[-1] == sizeof(TYPE)); - assert(((_strideBytes.Size() <= 2) || (_count == 1)) && "View does not guarantee linear access"); + assert(((index == 0) || (_strideBytes.Size() <= 2)) && "View does not guarantee linear access"); return _data[index]; } diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp index 571aee37..0b3e3de9 100644 --- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp +++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp @@ -2122,25 +2122,24 @@ Operation *TFLiteGraphOptimiser::ConvertPrelu(Graph *const graph, Operation *con int64_t alphaZp = 0; int alphaMin = 0; int alphaMax = 0; - if ( params->tensor->Type() == DataType::Int8 ) + BufferReader reader; + switch ( params->tensor->Type() ) { - auto *alphaBuf = alpha.Buffer()->Data(); - alphaMin = *std::min_element(alphaBuf, alphaBuf + alphaSize); - alphaMax = *std::max_element(alphaBuf, alphaBuf + alphaSize); - } - else if ( params->tensor->Type() == DataType::UInt8 ) - { - auto *alphaBuf = alpha.Buffer()->Data(); - alphaMin = *std::min_element(alphaBuf, alphaBuf + alphaSize); - alphaMax = *std::max_element(alphaBuf, alphaBuf + alphaSize); - } - else if ( params->tensor->Type() == DataType::Int16 ) - { - auto *alphaBuf = alpha.Buffer()->Data(); - alphaMin = *std::min_element(alphaBuf, alphaBuf + alphaSize); - alphaMax = *std::max_element(alphaBuf, alphaBuf + alphaSize); - } - + case DataType::Int8: + reader = alpha.Values(); + break; + case DataType::UInt8: + reader = alpha.Values(); + break; + case DataType::Int16: + reader = alpha.Values(); + break; + default: + assert(false); + }; + auto alphaMinMax = std::minmax_element(reader.begin(), reader.end()); + alphaMin = *alphaMinMax.first; + alphaMax = *alphaMinMax.second; if ( alphaQuant.zeroPoints.size() ) { alphaZp = alphaQuant.zeroPoints[0]; @@ -2163,8 +2162,7 @@ Operation *TFLiteGraphOptimiser::ConvertPrelu(Graph *const graph, Operation *con lreluOp->CopyOutput(TensorUsage::OFM, *ofmConn); auto *attr = lreluOp->Attribute(); attr->alpha = scaledAlphaMin; - // and then optimize LeakyRelU - returnOp = ConvertLeakyRelu(graph, lreluOp.get()); + returnOp = lreluOp.get(); RecordOptimisation(operation, returnOp); operation->Disconnect(); return returnOp; -- GitLab