diff --git a/ethosu/regor/compiler/scheduler_decompose.cpp b/ethosu/regor/compiler/scheduler_decompose.cpp index 357245e98eacb515ae77afe6fb0ec1541d51d440..44032b8314c1270a3ea360d3a3d5179651dedb66 100644 --- a/ethosu/regor/compiler/scheduler_decompose.cpp +++ b/ethosu/regor/compiler/scheduler_decompose.cpp @@ -1279,12 +1279,10 @@ std::vector> DecomposeTransposeConv2D(Archit return result; } -// TODO: Move this to run prior to decomposition. std::vector> LegaliseResize(Architecture *arch, std::unique_ptr op) { - // Convert ResizeBilinear/NearestNeighbor to a number of kernel 1x1 average pools with nearest neighbor x2 upScaling - // and a final average pool with a kernel size that depends upon the resize ops upScaling factor (x2, x4 or x8). The - // maximum upscale factor is limited to x8 because of the limit 8x8 kernel size limit for average pool with padding. + // Convert Resize (Bilinear or Nearest) into a sequence of 1×1 AvgPool ops followed by a final + // larger AvgPool / DepthwiseConv2D with kernel up to 8x8. std::vector> result; @@ -1296,14 +1294,39 @@ std::vector> LegaliseResize(Architecture *ar auto *attr = op->Attribute(); auto upscaleH = attr->scaleY.n; auto upscaleW = attr->scaleX.n; - auto remainingUpscale = std::max(upscaleW, upscaleH); bool canLegalise = true; ArchRequirements req{}; OperatorQuery(arch, op.get(), &req); - auto reqScale = QuantizedScale(1, IntLog2(attr->scaleX.n * attr->scaleY.n)); + auto ofmShape = ofmConn->shape; + auto ifmShape = ifmConn->shape; + + // half_pixel_centers / align_corners pattern match + bool isHalfPixelCenter = false; + if ( attr->scaleY.d == 2 && attr->scaleX.d == 2 && upscaleH % 2 == 0 && upscaleW % 2 == 0 ) + { + upscaleW /= 2; + upscaleH /= 2; + if ( attr->offset.x == -1 * (upscaleW - 1) && attr->offset.y == -1 * (upscaleH - 1) ) + { + isHalfPixelCenter = true; + } + } + auto remainingUpscale = std::max(upscaleW, upscaleH); + + bool isAlignCorners = false; + if ( std::max(float(ofmShape.Height() - 1) / std::max(ifmShape.Height() - 1, 1), + float(ofmShape.Width() - 1) / std::max(ifmShape.Width() - 1, 1)) == remainingUpscale ) + { + if ( !(ifmShape.Height() == 1 && ifmShape.Width() == 1) ) + { + isAlignCorners = true; + } + } + // Transform Gating + // Upscale must be one of 2, 4, or 8. if ( !IsPowerOfTwo(remainingUpscale) || remainingUpscale > 8 || remainingUpscale < 2 ) { canLegalise = false; @@ -1312,9 +1335,23 @@ std::vector> LegaliseResize(Architecture *ar { canLegalise = false; } - else if ( ofmConn->quantization.scales[0] != reqScale ) + else if ( attr->mode == tosa::ResizeMode::BILINEAR ) { - canLegalise = false; + auto reqScale = QuantizedScale(1, IntLog2(attr->scaleX.n * attr->scaleY.n)); + + if ( ofmConn->quantization.scales[0] != reqScale || isHalfPixelCenter || attr->offset.x != 0 || + attr->offset.y != 0 || attr->scaleX.d != 1 || attr->scaleY.d != 1 ) + { + canLegalise = false; + } + } + else if ( attr->mode == tosa::ResizeMode::NEAREST ) + { + // Must be one of align corners or half pixel centers + if ( !(isAlignCorners || isHalfPixelCenter) ) + { + canLegalise = false; + } } if ( !canLegalise ) @@ -1323,11 +1360,10 @@ std::vector> LegaliseResize(Architecture *ar return result; } - auto ofmShape = ofmConn->shape; - auto ifmShape = ifmConn->shape; - - ofmConn->tensor->dataType = ifmConn->tensor->dataType; + ofmConn->tensor->dataType = ifmConn->tensor->dataType; // Force OFM datatype to match IFM ifmConn->resamplingMode = ArchResampling::Nearest; + ifmConn->quantization = Quantization::Unit(); + // Perform 2x upScaling up to the last required while ( remainingUpscale > 2 ) { @@ -1349,13 +1385,65 @@ std::vector> LegaliseResize(Architecture *ar ifmConn->resamplingMode = ArchResampling::Nearest; auto newOp = std::make_unique(OpType::AvgPool); *newOp->ConnectInput(TensorUsage::IFM, ifmConn->tensor) = *ifmConn; - - Kernel kernel = Kernel::UnitKernel().WithPadding({0, 0, upscaleH - 1, upscaleW - 1, 0, 0}).WithSize({upscaleW, upscaleH}); + // Set Kernel + Kernel kernel = Kernel::UnitKernel(); + if ( attr->mode == tosa::ResizeMode::BILINEAR ) + { + if ( !isAlignCorners ) + { + kernel = kernel.WithPadding({0, 0, upscaleH - 1, upscaleW - 1, 0, 0}); + } + kernel = kernel.WithSize({upscaleW, upscaleH}); + } + else + { + if ( isAlignCorners ) + { + newOp = std::make_unique(OpType::DepthwiseConv2D); + *newOp->ConnectInput(TensorUsage::IFM, ifmConn->tensor) = *ifmConn; + // Weights + Shape wShape(ofmShape.Depth(), upscaleH, upscaleW, ofmShape.Depth()); + kernel = kernel.WithSize({upscaleW, upscaleH}); + auto wTensor = std::make_shared(DataType::Int8, wShape); + auto wTensorSrc = std::make_shared("resize_weights", DataType::Int8, wShape); + wTensorSrc->SetAxisOrder(AxisOrder::IHWO); + + const auto wSize = wShape.Elements(); + auto buffer = std::make_shared(std::make_unique(wSize), wSize); + BufferView bufferView(buffer, 0, DataTypeSizeBits(DataType::Int8), wShape, {}); + auto bufferValues = bufferView.WritableValues(); + + const auto h = upscaleH / 2; + const auto w = upscaleW / 2; + for ( int i = 0; i < ofmShape.Depth(); i++ ) + { + for ( int o = 0; o < ofmShape.Depth(); o++ ) + { + bufferValues[{i, h, w, o}] = 1; + } + } + wTensor->srcTensor = std::move(wTensorSrc); + wTensor->memArea = arch->ReadonlyMemory(); + wTensor->bufferView = std::move(bufferView); + newOp->ConnectInput(TensorUsage::Weights, wTensor); + wTensor->storageShape = wTensor->bufferView.ViewShape(); + newOp->Input(TensorUsage::Weights)->quantization = Quantization::Unit(); + // Zero bias + auto biasTensor = std::make_shared(DataType::Int32, Shape(1)); + auto bufBias = std::make_shared(Buffer::ConstValue(0)); + biasTensor->memArea = arch->ReadonlyMemory(); + biasTensor->bufferView = BufferView(bufBias, 0, DataTypeStorageSizeBits(biasTensor->dataType), {1}, {}); + biasTensor->storageShape = biasTensor->bufferView.ViewShape(); + newOp->ConnectInput(TensorUsage::Scales, biasTensor); + } + } newOp->SetKernel(kernel); + ofmConn->quantization = Quantization::Unit(); - ofmConn->rounding = RoundMode::AUTO; + ofmConn->rounding = RoundMode::DBL; *newOp->ConnectOutput(TensorUsage::OFM, ofmConn->tensor) = *ofmConn; result.emplace_back(std::move(newOp)); + return result; } diff --git a/ethosu/regor/tflite/tflite_supported_operators_u55.cpp b/ethosu/regor/tflite/tflite_supported_operators_u55.cpp index 56a53d089ae50bd688fc7e67d296cddd1441a31b..912965e0127994647de0faf93ed857bb5466f56d 100644 --- a/ethosu/regor/tflite/tflite_supported_operators_u55.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators_u55.cpp @@ -69,6 +69,7 @@ TfLiteSupportedOperatorsU55::TfLiteSupportedOperatorsU55(IArchitectureConstraint OpType::ExpandDims, OpType::ReduceSum, OpType::ResizeBilinear, + OpType::ResizeNearestNeighbor, OpType::Rsqrt, OpType::Pack, OpType::Unpack, @@ -149,10 +150,29 @@ bool TfLiteSupportedOperatorsU55::ConstraintBroadcastShapes(const Operation *op) bool TfLiteSupportedOperatorsU55::ConstraintResize(const Operation *op) { - if ( op->Type() != OpType::ResizeBilinear ) + OpType opType = op->Type(); + if ( !(opType == OpType::ResizeBilinear || opType == OpType::ResizeNearestNeighbor) ) { return true; } + bool halfPixelCentersRB = false; + bool alignCorners = false; + const auto *passthrough = static_cast(op->Passthrough()); + assert(passthrough); + + if ( opType == OpType::ResizeBilinear ) + { + const auto *opt = passthrough->builtin_options_as_ResizeBilinearOptions(); + assert(opt); + alignCorners = opt->align_corners(); + halfPixelCentersRB = opt->half_pixel_centers(); + } + else if ( opType == OpType::ResizeNearestNeighbor ) + { + const auto *opt = passthrough->builtin_options_as_ResizeNearestNeighborOptions(); + assert(opt); + alignCorners = opt->align_corners(); + } auto ifmConn = op->Input(TensorUsage::IFM); auto ofmConn = op->Output(TensorUsage::OFM); assert(ifmConn); @@ -169,29 +189,34 @@ bool TfLiteSupportedOperatorsU55::ConstraintResize(const Operation *op) return true; } - const auto *passthrough = static_cast(op->Passthrough()); - assert(passthrough); - const auto *opt = passthrough->builtin_options_as_ResizeBilinearOptions(); - assert(opt); - if ( opt->align_corners() ) + float hUpscale; + float wUpscale; + if ( alignCorners ) { - Failure(op, "Align Corners attribute is true", "Align Corners must be false"); - return false; + hUpscale = ofmShape.Height() == 1 ? 1 : float(ofmShape.Height() - 1) / (ifmShape.Height() - 1); + wUpscale = ofmShape.Width() == 1 ? 1 : float(ofmShape.Width() - 1) / (ifmShape.Width() - 1); + } + else + { + hUpscale = float(ofmShape.Height()) / ifmShape.Height(); + wUpscale = float(ofmShape.Width()) / ifmShape.Width(); } - if ( opt->half_pixel_centers() ) + + if ( halfPixelCentersRB ) { - Failure(op, "Half Pixel Centers attribute is true", "Half Pixel Centers must be false"); + Failure(op, "Half Pixel Centers attribute is true", "Half Pixel Centers must be false for Resize Bilinear"); return false; } std::string constraint = "If not (IFM H == IFM W == 1) and not IFM Shape == OFM Shape\n" "\tIf W upScale != H upScale:\n" "\t\tOFM W or H must be 1, and scaling in the dim that is must also be 1\n" - "\tIF W upScale == H upScale \n" - "\t\tupScale needs to be one of: 2x/4x/8x"; + "\tIf align corners:" + "\t\tupScale is definied as OFM H-1 / IFM H - 1" + "\tElse:" + "\t\tupScale is defined as OFM H/IFM H" + "\tupScale needs to be one of: 2x/4x/8x"; - int hUpscale = ofmShape.Height() / ifmShape.Height(); - int wUpscale = ofmShape.Width() / ifmShape.Width(); if ( hUpscale != wUpscale ) { @@ -204,10 +229,12 @@ bool TfLiteSupportedOperatorsU55::ConstraintResize(const Operation *op) return false; } } - else if ( !((ifmShape.Height() == 1 && ifmShape.Width() == 1) || (ofmShape.Height() % (2 * ifmShape.Height()) == 0 && hUpscale > 1 && hUpscale <= 8)) ) + + auto upscale = std::max(hUpscale, wUpscale); + if ( !((ifmShape.Height() == 1 && ifmShape.Width() == 1) || + (std::trunc(upscale) == upscale && IsPowerOfTwo(int(upscale)) && upscale > 1 && upscale <= 8)) ) { - Failure(op, - fmt::format("Scaling matches and operation has unsupported scaling={}", float(ofmShape.Height()) / ifmShape.Height()), constraint); + Failure(op, fmt::format("Scaling matches and operation has unsupported upScaling={}", upscale), constraint); return false; } return true;