From 5078129cab8a8e0be0e513cf510af99f2705ae18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Alfv=C3=A9n?= Date: Tue, 21 Jan 2025 15:57:21 +0100 Subject: [PATCH] MLBEDSW-10285: MLCE: Output diff caused by wrong ifm box MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - A slice op followed by a conv2d with stride 2 caused an output diff - The slice read is moved to the consumer (conv2d) but the problem in this case was that the ifm box calculation was not correct when having a stride greater than one - The issue is solved by backporting various fixes from Regor that make sure the ifm and ofm boxes have correct offsets and sizes - Also fixed a hidden problem where read_shape in rewrite_split_ops was calculated erroneously, since the start and end offsets can have rank less than 4 but the ifm shape is always rank 4. That gave a corrupt read_shape. However, read_shape height was not used before this commit, so the corrupt value was never used and did not cause any problems Change-Id: Ib71c13cfecf77b2cdc2b5aaf437938577c433bb5 Signed-off-by: Johan Alfvén --- ethosu/vela/high_level_command_stream.py | 16 +++++++++++++--- ethosu/vela/tflite_graph_optimiser.py | 6 ++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py index bb4ddfec..363bb34b 100644 --- a/ethosu/vela/high_level_command_stream.py +++ b/ethosu/vela/high_level_command_stream.py @@ -99,7 +99,9 @@ class Box: new_start_coord[-2] = max(new_start_coord[-2] * stride - skirt[1], 0) new_end_coord[-2] = min(new_end_coord[-2] * stride + skirt[3], ifm_shape.width) else: - new_start_coord[-2] = max(new_start_coord[-2] * stride - skirt[1], split_offset[-2]) + new_start_coord[-2] = max( + new_start_coord[-2] * stride - (skirt[1] + split_offset[-2]), split_offset[-2] + ) new_end_coord[-2] = min(new_end_coord[-2] * stride + skirt[3], split_offset[-2] + split_shape[-2]) if len(new_start_coord) >= 3: @@ -107,7 
+109,15 @@ class Box: skirt_top_remainder = skirt[0] % upscaling_factor total_stride = stride * (new_end_coord[-3] - new_start_coord[-3] - 1) - new_start_coord[-3] = new_start_coord[-3] * stride - skirt[0] + skirt_top_remainder + valid_ifm_height = ifm_shape.height + if split_offset is None: + new_start_coord[-3] = new_start_coord[-3] * stride - skirt[0] + skirt_top_remainder + else: + new_start_coord[-3] = max( + new_start_coord[-3] * stride - (skirt[0] + split_offset[-3]) + skirt_top_remainder, + split_offset[-3], + ) + valid_ifm_height = min(valid_ifm_height, split_offset[-3] + split_shape[-3]) pad_top = max(0, 0 - new_start_coord[-3]) + skirt_top_remainder new_start_coord[-3] = max(new_start_coord[-3], 0) @@ -127,7 +137,7 @@ class Box: # Adjust for upscaling new_start_coord[-3] = max(new_start_coord[-3] // upscaling_factor, 0) new_end_coord[-3] = new_end_coord[-3] * stride + skirt[2] + (skirt[2] % upscaling_factor) - new_end_coord[-3] = max(min(new_end_coord[-3] // upscaling_factor, ifm_shape.height), 1) + new_end_coord[-3] = max(min(new_end_coord[-3] // upscaling_factor, valid_ifm_height), 1) # Wrap the IFMs of broadcasted binary elementwise ops # at the limits of the non-broadcasted volumes diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 18bc1a23..62afc009 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -152,8 +152,10 @@ def rewrite_split_ops(tens, arch, nng): else: # The read shape is relative to each start offset # Limit read shape to the size of the IFM - offset is not necessarily limited - ifm_dims = split_op.ifm_shapes[0].as_list() - read_shape = Shape4D([min(oe, ifm_dim) - os for oe, os, ifm_dim in zip(offset_end, offset_start, ifm_dims)]) + ifm_dims_4D = split_op.ifm_shapes[0].as_list() + offset_end_4D = Shape4D(offset_end).as_list() + offset_start_4D = Shape4D(offset_start).as_list() + read_shape = Shape4D([min(oe, ifm_dim) - os for oe, os, ifm_dim in 
zip(offset_end_4D, offset_start_4D, ifm_dims_4D)]) # For Split the offset cannot be extracted from the tensor so it has to # be calculated from the index of the output tensor -- GitLab