diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp index 962e86aa8a46bc9889888c8d15c40148b3f8ba3a..0390f4730d201e0c06084db10dd59bbad043cbcd 100644 --- a/ethosu/regor/compiler/graphir_optimiser.cpp +++ b/ethosu/regor/compiler/graphir_optimiser.cpp @@ -2057,7 +2057,28 @@ Operation *GraphIrOptimiser::ReshapeReverse(Graph *const graph, Operation *const return returnOp; } -// Reshape ArgMax input/outputs to 3D-tensors where W is the reduced axis +/** + * Rewrite the argmax operation (WxHxC -> WxHx1) to the OpTypes: DepthwiseConv2D -> MaxPool -> LUT -> Cast + * + * DepthwiseConv2D: + * 1x1 kernel with weights that bit shift each value in the input tensor 7 bits to the left. (Weights equal to 1 << 7) + * A bias corresponding to the channel index in reverse order is added to the kernel to pack channel information into the + * value. [0x04, 0x13, 0x0a, 0x02] = [0000 0100, 0001 0011, 0000 1010, 0000 0010] -> + * [...0010 0000 0000, ...1001 1000 0000, ...0101 0000 0000, ...0001 0000 0000] + [...0011, ...0010, ...0001, ...0000] + * -> + * [...0010 0000 0011, ...1001 1000 0010, ...0101 0000 0001, ...0001 0000 0000] + * + * MaxPool: + * The max pool operation selects the maximum value along the channel and flattens the depth to 1 (WxHxC -> WxHx1) + * [...0010 0000 0011, ...1001 1000 0010, ...0101 0000 0001, ...0001 0000 0000] -> [...1001 1000 0010] + * + * LUT: + * The lookup table is used to retrieve the channel information from the reverse order channel index value in the + * max pool output tensor (the least significant bytes). [0000 1001 1000 0010] -> [0000 0000 0000 0001] = 0x0001 + * + * Cast: + * Finally the value is cast to the correct output type. 
0x0001 -> 0x00000001 + */ Operation *GraphIrOptimiser::RewriteArgmax(Graph *const graph, Operation *const operation) { Operation *returnOp = operation; @@ -2071,8 +2092,6 @@ Operation *GraphIrOptimiser::RewriteArgmax(Graph *const graph, Operation *const auto *ofmConn = operation->Output(TensorUsage::OFM); auto &ifmShape = ifmConn->shape; auto &ofmShape = ofmConn->shape; - int ifmRank = ifmConn->shape.Size(); - int axis = attr->axis; // Extend OfmShape to match ifmRank if ( ofmShape.Size() != ifmShape.Size() ) @@ -2081,17 +2100,91 @@ Operation *GraphIrOptimiser::RewriteArgmax(Graph *const graph, Operation *const assert(ofmShape.Size() == ifmShape.Size()); } - // Reshape IFM and OFM to 3D-tensors where W is the reduced axis - if ( attr->axis != 1 || ifmRank != 3 ) + // If native support exists for argmax, we return the argmax op without decomposing. + if ( _constraints->OperatorQuery(OpType::ArgMax).Any(QueryResult::Native) ) { - ifmShape = ReshapeTo3DAroundAxis(ifmShape, axis); - ofmShape = ifmShape.WithWidth(1); - attr->axis = 1; + // Reshape IFM and OFM to 3D-tensors where W is the reduced axis + if ( attr->axis != 1 || ifmConn->shape.Size() != 3 ) + { + ifmShape = ReshapeTo3DAroundAxis(ifmShape, attr->axis); + ofmShape = ifmShape.WithWidth(1); + attr->axis = 1; + } + operation->Output(TensorUsage::OFM)->Set(RoundMode::TRUNCATE_TO_LOWER); + // Update kernel based on reshapes + std::unique_ptr kernel = std::make_unique(Point2i(ifmShape[1], 1), Point2i(1, 1), Point2i(1, 1)); + operation->SetKernel(std::move(kernel)); + + return returnOp; } - operation->Output(TensorUsage::OFM)->Set(RoundMode::TRUNCATE_TO_LOWER); - // Update kernel based on reshapes - std::unique_ptr kernel = std::make_unique(Point2i(ifmShape[1], 1), Point2i(1, 1), Point2i(1, 1)); - operation->SetKernel(std::move(kernel)); + + // Pad both OFM and IFM to 3D before extracting c, w, h + ifmShape = Shape::PadAxes(ifmShape, 3, 1); + ofmShape = Shape::PadAxes(ofmShape, 3, 1); + int c = ifmShape.Depth(); + 
int w = ifmShape.Width(); + int h = ifmShape.Height(); + + // Create tensors to hold intermediate values for conv, max pool and lut output. + std::shared_ptr convOutput = std::make_shared("convOutputTensor", DataType::Int16, Shape(h, w, c)); + std::shared_ptr maxPoolOutput = std::make_shared("maxPoolOutputTensor", DataType::Int16, Shape(h, w, 1)); + std::shared_ptr lutOutput = std::make_shared("lutOutputTensor", DataType::Int16, Shape(h, w, 1)); + + // Create values for the channel information for the Conv2D bias. + std::vector reverse_idx(c); + std::iota(reverse_idx.begin(), reverse_idx.end(), 0); // reverse_idxs = [0, 1, 2, 3, 4] + std::reverse(reverse_idx.begin(), reverse_idx.end()); // reverse_idxs = [4, 3, 2, 1, 0] + + // Create a constant tensor with the channel information values. + auto biasBuffer = std::make_shared(std::move(reverse_idx)); + Shape biasShape(1, 1, 1, c); + auto convBias = CreateConstTensor("bias", DataType::Int64, biasBuffer, &biasShape); + + // Create weights-tensor with 1x1 kernel. + Shape weightShape(1, 1, 1, c); + std::vector values(weightShape.Elements(), 1 << 7); // Weights are 128 to mimic the bit shift. + auto weightBuf = std::make_shared(std::move(values)); + const auto weightTensor = std::make_shared("convOp_unitWeights", DataType::UInt8, weightShape, weightBuf); + weightTensor->SetAxisOrder(AxisOrder::IHWO); + + // Create a convolution operation for shifting values 7 bits (multiplying by 2**7) + auto convOp = std::make_shared(OpType::DepthwiseConv2D); + convOp->SetKernel(std::make_unique(Point2i(1, 1), Point2i(1, 1), Point2i(1, 1))); + convOp->CopyInput(TensorUsage::IFM, *ifmConn); + convOp->ConnectInput(TensorUsage::Weights, weightTensor).Set(Quantization::Unit()); + // Add bias to the convolution corresponding to the input channel + convOp->ConnectInput(TensorUsage::Scales, convBias).Set(Quantization::Unit()); + convOp->ConnectOutput(TensorUsage::OFM, convOutput).Set(ifmConn->quantization); + + // Max pool op. 
Squash the width dimension into height since max pool can only work in 2D. 1xHxWxC -> 1x(WxH)xCx1 + auto maxPoolOp = std::make_shared(OpType::MaxPool); + int newHeight = h * w; + maxPoolOp->ConnectInput(TensorUsage::IFM, convOutput).Set(ifmConn->quantization).Set(Shape(1, newHeight, c, 1)); + maxPoolOp->SetKernel(std::make_unique(Point2i(c, 1), Point2i(1, 1), Point2i(1, 1))); + maxPoolOp->ConnectOutput(TensorUsage::OFM, maxPoolOutput).Set(Quantization::Unit()).Set(Shape(1, newHeight, 1, 1)); + + // Create new tensor for LUT + Shape lutShape(1, 1, 1, 512); + uint32_t slope = (-128 & 0xFFFF) << 16; + uint32_t base = c - 1; + std::vector lutValues(512, slope + base); + auto lutBuf = std::make_shared(std::move(lutValues)); + auto lutTensor = CreateConstTensor("lutTensor", DataType::UInt32, lutBuf, &lutShape); + + // Use the LUT operator to extract the channel information from the lower 7 bits + Operation *lutOp = CreateLUT(maxPoolOutput, lutTensor, ifmConn->quantization, ofmConn->quantization, DataType::UInt32); + lutOp->ConnectInput(TensorUsage::IFM, maxPoolOutput).Set(Quantization::Unit()); + lutOp->ConnectOutput(TensorUsage::OFM, lutOutput).Set(Quantization::Unit()).Set(Shape(h, w, 1)); + + // Cast the LUT output back to the OFM type + auto castOp = std::make_shared(OpType::Cast); + castOp->ConnectInput(TensorUsage::IFM, lutOutput).Set(Quantization::Unit()); + castOp->CopyOutput(TensorUsage::OFM, *ofmConn); + + // Set the return OP to the cast and disconnect the input operation before we return from the function + returnOp = castOp.get(); + operation->Disconnect(); + return returnOp; } diff --git a/ethosu/regor/tflite/tflite_supported_operators_u55.cpp b/ethosu/regor/tflite/tflite_supported_operators_u55.cpp index 33eeaf478b6019715625187f4f847daa69fe085b..e10c400d8871f2299bc793c5d32845c34a000961 100644 --- a/ethosu/regor/tflite/tflite_supported_operators_u55.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators_u55.cpp @@ -84,7 +84,7 @@ 
TfLiteSupportedOperatorsU55::TfLiteSupportedOperatorsU55(IArchitectureConstraint DataType::UInt8, DataType::Int8, DataType::Int16, - DataType::Int32 + DataType::Int32, // clang-format on }; _maxWeightSum8Bit = 127 * (1 << 16); @@ -94,6 +94,9 @@ TfLiteSupportedOperatorsU55::TfLiteSupportedOperatorsU55(IArchitectureConstraint &TfLiteSupportedOperatorsU55::ConstraintBroadcastShapes, &TfLiteSupportedOperatorsU55::ConstraintReverse, &TfLiteSupportedOperatorsU55::Constraint32bitOps, + &TfLiteSupportedOperatorsU55::ConstraintArgMaxDepth, + &TfLiteSupportedOperatorsU55::ConstraintArgMaxAxis, + &TfLiteSupportedOperatorsU55::ConstraintArgMaxOverflow, // TODO: Remove after MLBEDSW-9758: TOSA MaxPool decomp &TfLiteSupportedOperatorsU55::ConstraintKernelStride, &TfLiteSupportedOperatorsU55::ConstraintUnrolledKernelStride, &TfLiteSupportedOperatorsU55::ConstraintMatmul, @@ -203,6 +206,62 @@ bool TfLiteSupportedOperatorsU55::Constraint32bitOps(const Operation *op) return true; } + +// Check that depth is not greater than 127 +bool TfLiteSupportedOperatorsU55::ConstraintArgMaxDepth(const Operation *op) +{ + if ( op->Type() != OpType::ArgMax ) + { + return true; + } + int depth = op->Input(TensorUsage::IFM)->shape.Depth(); + if ( depth > 127 ) + { + Failure(op, fmt::format("The depth of the argmax: {}, is over the limit: 127.", depth)); + return false; + } + return true; +} + +// Check that the operations are performed along the depth axis +bool TfLiteSupportedOperatorsU55::ConstraintArgMaxAxis(const Operation *op) +{ + if ( op->Type() != OpType::ArgMax ) + { + return true; + } + auto axis = op->Attribute()->axis; + const int noAxes = op->Input(TensorUsage::IFM)->shape.Size(); + if ( axis != noAxes - 1 ) + { + Failure(op, fmt::format("The axis of the argmax: {}, is not equal to the index of the depth axis: {} ", axis, noAxes - 1)); + return false; + } + return true; +} + +// TODO: Remove this constraint when MLBEDSW-9758: decomposition for max pooling has been implemented. 
+bool TfLiteSupportedOperatorsU55::ConstraintArgMaxOverflow(const Operation *op) +{ + if ( op->Type() != OpType::ArgMax ) + { + return true; + } + auto ifmConn = op->Input(TensorUsage::IFM); + assert(ifmConn); + static constexpr int maxProd = 1 << 16; + const auto &ifmShape = ifmConn->shape; + int w = ifmShape.Size() > 1 ? ifmShape.Width() : 1; + int h = ifmShape.Size() > 2 ? ifmShape.Height() : 1; + if ( w * h > maxProd ) + { + Failure(op, fmt::format("ifmShape: ({}), W * H = {}", ifmShape.ToString(), ifmShape.ElementsWH()), + "The product of IFM width and height must be less than 65536"); + return false; + } + return true; +} + bool TfLiteSupportedOperatorsU55::ConstraintKernelStride(const Operation *op) { const auto kernel = op->Kernel(); @@ -322,5 +381,4 @@ bool TfLiteSupportedOperatorsU55::ConstraintMatmul(const Operation *op) } return true; } - } // namespace regor diff --git a/ethosu/regor/tflite/tflite_supported_operators_u55.hpp b/ethosu/regor/tflite/tflite_supported_operators_u55.hpp index 7d626ea129d02d97df8d9edc517da946cc72ff11..2f5fa8ab347e41199e86fe1d85036a9074fc3541 100644 --- a/ethosu/regor/tflite/tflite_supported_operators_u55.hpp +++ b/ethosu/regor/tflite/tflite_supported_operators_u55.hpp @@ -46,5 +46,8 @@ private: bool ConstraintKernelStride(const Operation *op); bool ConstraintUnrolledKernelStride(const Operation *op); bool ConstraintMatmul(const Operation *op); + bool ConstraintArgMaxDepth(const Operation *op); + bool ConstraintArgMaxAxis(const Operation *op); + bool ConstraintArgMaxOverflow(const Operation *op); // TODO: Remove after MLBEDSW-9758: TOSA MaxPool decomp }; } // namespace regor