diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md index d8d86678f03439f7f7966d9f5c2aae1a30329822..332bc8c16b726c9736a01d0965bfdacf11ce28e2 100644 --- a/SUPPORTED_OPS.md +++ b/SUPPORTED_OPS.md @@ -416,7 +416,7 @@ This is a list of constraints that the RSQRT operator must satisfy in order to b - At least one Input's shape must match the OFM's shape - IFM and OFM data types must match -- IFM must be int8 +- IFM must be int8 or uint8 ### Ethos-U55 and Ethos-U65 TFLite SLICE Constraints diff --git a/ethosu/vela/lut.py b/ethosu/vela/lut.py index 7b563b087f5f99ecd46b4e00aa90e29e24197d6f..7024b6e85fd4b6762cbb5fea53c6332148ce8ef2 100644 --- a/ethosu/vela/lut.py +++ b/ethosu/vela/lut.py @@ -1,4 +1,5 @@ # SPDX-FileCopyrightText: Copyright 2020-2021, 2023-2024 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2025 Meta Platforms, Inc. and affiliates. # # SPDX-License-Identifier: Apache-2.0 # @@ -246,7 +247,7 @@ def create_lut_int16_op(op, lut_fn, fn_name): return convert_to_lut(op, lut, fn_name) -def create_lut_rsqrt_int8_op(op): +def create_lut_rsqrt_8bit_op(op): # Turn off black formatting for the LUT tables to keep them compact # fmt: off @@ -301,16 +302,13 @@ def create_lut_rsqrt_int8_op(op): # Shift modification (value used in reference but Vela has opposite sign) kshift = -20 - ix = range(-128, 128) + ix = range(256) if op.ifm.dtype == DataType.uint8 else range(-128, 128) quantized_min = min(ix) quantized_max = max(ix) # Any value close to 0 (zero index in LUT) is mapped to the max output value values = [quantized_max] - for x in ix: - if x == -128: - # Value already populated above - continue + for x in ix[1:]: # Rsqrt is only defined for positive values x_real = max(0, x - zp_in) val = RSQRT_LUT[x_real] diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py index 0c7039b57e1ac1e5034731e35f0f508386a29f47..ef41d1dd0a5847b4fa8d6da20508b0f891fc6515 100644 --- a/ethosu/vela/test/test_tflite_supported_operators.py +++ b/ethosu/vela/test/test_tflite_supported_operators.py @@ -1,4 +1,5 @@ # SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2025 Meta Platforms, Inc. and affiliates. # # SPDX-License-Identifier: Apache-2.0 # @@ -740,9 +741,9 @@ def test_rsqrt_support(): # Test supported op (int8) op = testutil.create_elemwise_op(Op.Rsqrt, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8) assert support.is_operator_supported(op) - # Test not supported op (uint8) + # Test supported op (uint8) op = testutil.create_elemwise_op(Op.Rsqrt, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.uint8) - assert not support.is_operator_supported(op) + assert support.is_operator_supported(op) # Test not supported op (int16) op = testutil.create_elemwise_op(Op.Rsqrt, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int16) assert not support.is_operator_supported(op) diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 1834638dfa30bbb805b1b535ea5e822af8ed7996..feae31ae7262deec7fff7d075bdb4da9896312d2 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -1,4 +1,5 @@ # SPDX-FileCopyrightText: Copyright 2020-2025 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2025 Meta Platforms, Inc. and affiliates. # # SPDX-License-Identifier: Apache-2.0 # @@ -46,7 +47,7 @@ from .lstm import Lstm from .lut import convert_to_lut from .lut import create_lut_8bit_op from .lut import create_lut_int16_op -from .lut import create_lut_rsqrt_int8_op +from .lut import create_lut_rsqrt_8bit_op from .numeric_util import clamp_sigmoid from .numeric_util import full_shape from .numeric_util import round_away_zero @@ -2572,7 +2573,7 @@ def convert_ops_to_lut(op: Operation, arch, nng) -> Operation: assert False, f"Unsupported data type {op.ifm.dtype} for {op.type}" if op.type == Op.Rsqrt: - return create_lut_rsqrt_int8_op(op) + return create_lut_rsqrt_8bit_op(op) return op @@ -3008,6 +3009,25 @@ def convert_conv_groups(op: Operation, arch, nng): return op +def rewrite_rsqrt(op: Operation, arch, nng) -> Operation: + if op.type != Op.Rsqrt: + return op + + ifm, ofm = op.get_ifm_ofm() + if ifm.dtype != DataType.int8 or ifm.dtype != ofm.dtype: + return op + + prev_op = ifm.ops[0] + next_op = ofm.consumer_list[0] + + if prev_op.type == Op.Quantize and prev_op.ifm.dtype == DataType.uint8: + op.set_input_tensor(prev_op.ifm, 0) + op.ifm.consumer_list.remove(prev_op) + + if next_op.type == Op.Quantize and next_op.ofm.dtype == DataType.uint8: + op.set_output_tensor(next_op.ofm) + + return op def supported_operator_check(op, arch, nng): op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op) @@ -3083,6 +3103,16 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights): rewrite_unsupported=False, ) + for idx, sg in enumerate(nng.subgraphs): + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, + sg, + arch, + [], + [rewrite_rsqrt], + rewrite_unsupported=False, + ) + # Rewrite of operators op_rewrite_list = [ set_tensor_equivalence, diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 1fd284b43155193aa251681afd2af35de0315420..8eba7d8c4d7ff0b948b19caadf641b3717ba2d9a 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -1,4 +1,5 @@ # SPDX-FileCopyrightText: Copyright 2020-2025 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2025 Meta Platforms, Inc. and affiliates. # # SPDX-License-Identifier: Apache-2.0 # @@ -357,7 +358,7 @@ class TFLiteSupportedOperators: self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_lstm_weight_dimensions) # Rsqrt specific checks - self.specific_constraints[Op.Rsqrt].append(TFLiteSupportedOperators.constraint_rsqrt_input_int8) + self.specific_constraints[Op.Rsqrt].append(TFLiteSupportedOperators.constraint_rsqrt_input_8bit) # Slice specific checks: self.specific_constraints[Op.Slice].append(TFLiteSupportedOperators.constraint_slice_inputs_const) @@ -1061,10 +1062,10 @@ class TFLiteSupportedOperators: return valid, "Op recurrent weights are not 2D" @staticmethod - def constraint_rsqrt_input_int8(op): - "IFM must be int8" + def constraint_rsqrt_input_8bit(op): + "IFM must be int8 or uint8" ifm_dtype = op.ifm.dtype - valid = ifm_dtype == DataType.int8 + valid = ifm_dtype == DataType.int8 or ifm_dtype == DataType.uint8 return valid, f"Op has ifm_dtype={ifm_dtype}" @staticmethod