From 7e54e6c6e3f8f08ea4145979416d7563b1754650 Mon Sep 17 00:00:00 2001
From: Philip Hall
Date: Wed, 29 Jan 2025 18:12:43 +0000
Subject: [PATCH] MLBEDSW-10106: Ethos-U55 MatMul implementation

Implementation of TOSA MatMul for Ethos-U55.
- Added scratch tensor return to operator queries.
- Allowed for differing input/output tensor formats.
- Cleanup of tensor format assignment.

Signed-off-by: Philip Hall
Change-Id: I14cfa158bd0843ce6ff1516dfdb880f7100cc94a
---
 ethosu/regor/architecture/architecture.hpp         |   1 +
 .../architecture/architecture_constraints.hpp      |   9 +-
 .../regor/architecture/ethosu55/ethos_u55.cpp      |  57 ++++++-
 .../regor/architecture/ethosu55/ethos_u55.hpp      |  14 +-
 .../ethosu55/ethos_u55_constraints.cpp             |  23 ++-
 .../ethosu55/ethos_u55_performance.cpp             |   7 +-
 .../ethos_u55_register_cs_generator.cpp            | 144 +++++++++++++++++-
 .../ethos_u55_register_cs_generator.hpp            |   8 +-
 .../ethosu85/ethos_u85_constraints.cpp             |   7 +
 ethosu/regor/common/box.hpp                        |   9 ++
 ethosu/regor/common/shape.hpp                      |  15 +-
 .../high_level_command_stream_generator.cpp        |  45 ++++--
 ethosu/regor/compiler/scheduler.cpp                |  73 +++++----
 ethosu/regor/compiler/scheduler_decompose.cpp      |   5 +-
 ethosu/regor/compiler/scheduler_operation.hpp      |  13 ++
 ethosu/regor/compiler/scheduler_packing.cpp        |   1 +
 .../tflite/tflite_supported_operators_u55.cpp      |   1 +
 17 files changed, 353 insertions(+), 79 deletions(-)

diff --git a/ethosu/regor/architecture/architecture.hpp b/ethosu/regor/architecture/architecture.hpp
index 700f6840..993df631 100644
--- a/ethosu/regor/architecture/architecture.hpp
+++ b/ethosu/regor/architecture/architecture.hpp
@@ -284,6 +284,7 @@ struct ElementAccess
     int ofmWrite = 0;
     int weightsRefetch = 0;
     int constRead[2] = {0, 0};
+    int tmpRead = 0, tmpWrite = 0;
 };
 
 enum class MemChannel
diff --git a/ethosu/regor/architecture/architecture_constraints.hpp b/ethosu/regor/architecture/architecture_constraints.hpp
index 07db143f..e08b1c24 100644
--- a/ethosu/regor/architecture/architecture_constraints.hpp
+++ b/ethosu/regor/architecture/architecture_constraints.hpp
@@ -58,8 +58,9 @@ enum class ArchRequirement
 {
     None = 0,
     ScratchTensor = 1,
-    OutputFormat = 2,
-    OpSubstitution = 4,
+    OpSubstitution = 2,
+    OutputFormat = 4,
+    InputFormat = 8,
 };
 
 struct ArchRequirements
@@ -71,6 +72,8 @@ struct ArchRequirements
         DataType type = DataType::None;
         TensorFormat format = TensorFormat::Unknown;
     } scratch;
+    TensorFormat ifmFormat = TensorFormat::Unknown;
+    TensorFormat ifm1Format = TensorFormat::Unknown;
     TensorFormat ofmFormat = TensorFormat::Unknown;
     OpType substitution = OpType::None;
 };
@@ -95,8 +98,10 @@ enum class QueryResult
     Native = 2,
     Constrained = 4,
     HasRequirements = 8,
+    Decompose = 16,
     NativeHasReq = Native | HasRequirements,
     NativeConstrained = Native | Constrained,
+    NativeDecompose = Native | Decompose,
     NativeConstrainedHasReq = Native | Constrained | HasRequirements,
 };
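The two enums above are bit flags. A minimal consumer-side sketch of the intended query protocol (hypothetical caller; Flags<T>::Set/Any and the '%' containment test are used here as they appear elsewhere in this patch):

    ArchOperatorQuery query;    // filled in with ifm/ofm shapes by the caller
    ArchRequirements req;
    Flags<QueryResult> result = constraints->OperatorQuery(OpType::MatMul, &query, &req);
    if ( result.Any(QueryResult::HasRequirements) )
    {
        if ( req.req % ArchRequirement::ScratchTensor )
        {
            // Caller must provide a scratch tensor of req.scratch.size/type/format
        }
        if ( req.req % ArchRequirement::InputFormat )
        {
            // Caller must place IFM1 in req.ifm1Format (TensorFormat::Unknown = unconstrained)
        }
    }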
diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55.cpp b/ethosu/regor/architecture/ethosu55/ethos_u55.cpp
index a37c15da..d03ff1a2 100644
--- a/ethosu/regor/architecture/ethosu55/ethos_u55.cpp
+++ b/ethosu/regor/architecture/ethosu55/ethos_u55.cpp
@@ -166,7 +166,8 @@ void ArchEthosU55::ApplyConfig(const AcceleratorConfig *cfg)
 
     // SHRAM layout information
     _shram.reservedOutputBanks = 2;
-    _shram.bankSizeBytes = 1024, _shram.totalBanks = cfg->shramBanks;
+    _shram.bankSizeBytes = 1024;
+    _shram.totalBanks = cfg->shramBanks;
     _shram.reservedEndBanks = (_shram.totalBanks > 16) ? 2 : 0;
 
     _shramMemory = std::make_unique<ArchitectureMemory>("shram", _shram.bankSizeBytes * _shram.totalBanks);
@@ -182,6 +183,33 @@ void ArchEthosU55::ApplyConfig(const AcceleratorConfig *cfg)
 
 std::unique_ptr<ArchitectureOpConfig> ArchEthosU55::GetOpConfig(OpType opType, const ArchitectureConfigQuery &query)
 {
+    // Compound configuration:
+    if ( opType == OpType::MatMul )
+    {
+        ArchitectureConfigQuery tmpQuery = query;
+        Kernel unitKernel = Kernel::UnitKernel();
+        int batches = query.ofmShape.Height();
+        // Block configuration for the Elementwise Mul
+        tmpQuery.kernel = &unitKernel;
+        tmpQuery.ifmBits = query.ifmBits;
+        tmpQuery.ifmShape[1] = Shape(1, batches, 1, query.ifmShape[1].Depth());
+        tmpQuery.ofmShape = query.ifmShape[0];
+        tmpQuery.ofmFormat = TensorFormat::NHWC;
+        tmpQuery.ofmBits = 32;
+        tmpQuery.transpose = TransposeType::None;
+        auto mulConfig = FindBlockConfig(OpType::Mul, tmpQuery);
+        // Block configuration for the ReduceSum
+        tmpQuery.ofmShape = Shape(1, batches, query.ifmShape[0].Width(), 1);
+        tmpQuery.ofmBits = query.ofmBits;
+        tmpQuery.ofmFormat = query.ofmFormat;
+        auto reduceConfig = FindBlockConfig(OpType::ReduceSum, tmpQuery);
+        assert(mulConfig.get());
+        assert(reduceConfig.get());
+        reduceConfig->AttachPrevConfig(std::move(mulConfig));
+
+        return std::unique_ptr<ArchitectureOpConfig>(reduceConfig.release());
+    }
+    // Single op configurations
     auto config = FindBlockConfig(opType, query);
     return config;
 }
@@ -285,7 +313,7 @@ static Shape FitBlockForOFM(const Shape &ofmShape, const Kernel *kernel, const S
 }
 
 
-std::unique_ptr<ArchitectureOpConfig> ArchEthosU55::FindBlockConfig(OpType opType, const ArchitectureConfigQuery &query)
+std::unique_ptr<EthosU55OpConfig> ArchEthosU55::FindBlockConfig(OpType opType, const ArchitectureConfigQuery &query)
 {
     assert(query.ifmBits > 0 && query.ifmBits <= 32);
     assert(query.ofmShape.Size() > 2 && "Insufficient dimensions to search for block config");
@@ -298,6 +326,12 @@
 
     EthosU55NpuOp npuOp = GetHWOp(opType);
     assert(npuOp != EthosU55NpuOp::None);
+    if ( (npuOp == EthosU55NpuOp::Compound) && (opType == OpType::MatMul) )
+    {
+        // The block config of the final output operator
+        npuOp = EthosU55NpuOp::ReduceSum;
+        opType = OpType::ReduceSum;
+    }
 
     // Figure out if SHRAM should be portioned for elementwise
     ElementwiseUsage ewUsage = ElementwiseUsage::No;
@@ -493,11 +527,11 @@
 
     // Return the best configuration
     if ( bestCost != std::numeric_limits<float>::infinity() )
     {
-        return std::unique_ptr<ArchitectureOpConfig>(config.release());
+        return config;
     }
 
     // Didn't find a configuration
-    return std::unique_ptr<ArchitectureOpConfig>();
+    return {};
 }
@@ -591,6 +625,10 @@ std::unique_ptr<ArchitectureOpConfig> EthosU55OpConfig::Clone()
     config->_ofmBlock = _ofmBlock;
     config->_ifmBlock = _ifmBlock;
     config->_layout = _layout;
+    if ( _prevConfig )
+    {
+        config->_prevConfig.reset(static_cast<EthosU55OpConfig *>(_prevConfig->Clone().release()));
+    }
     return std::unique_ptr<ArchitectureOpConfig>(config.release());
 }
@@ -621,6 +659,17 @@ std::string EthosU55OpConfig::ToString(bool full)
     return tmp;
 }
 
+void EthosU55OpConfig::AttachPrevConfig(std::unique_ptr<EthosU55OpConfig> prev)
+{
+    _prevConfig = std::move(prev);
+}
+
+EthosU55OpConfig *EthosU55OpConfig::PrevConfig()
+{
+    return _prevConfig.get();
+}
+
+
 EthosU55NpuOp ArchEthosU55::GetHWOp(OpType type)
 {
     static const std::unordered_map<OpType, EthosU55NpuOp> toNpuOp = {
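For MatMul, GetOpConfig above returns a compound configuration: the ReduceSum (final stage) block config with the elementwise Mul (first stage) config chained behind it. A minimal sketch of how a consumer unpacks the chain (this mirrors InsertMatMulCommand later in this patch):

    auto *reduceConfig = static_cast<EthosU55OpConfig *>(config);  // config from GetOpConfig(OpType::MatMul, ...)
    EthosU55OpConfig *mulConfig = reduceConfig->PrevConfig();      // nullptr for single-op configs
    // mulConfig drives the Mul stage; reduceConfig drives the ReduceSum stage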
diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55.hpp b/ethosu/regor/architecture/ethosu55/ethos_u55.hpp
index 7b87920a..f8e60644 100644
--- a/ethosu/regor/architecture/ethosu55/ethos_u55.hpp
+++ b/ethosu/regor/architecture/ethosu55/ethos_u55.hpp
@@ -33,7 +33,7 @@
 namespace regor
 {
 
-enum EthosU55SHRamElements
+enum EthosU55SHRamElements : uint8_t
 {
     SHRAM_IFM8 = 0,
     SHRAM_IFM16 = 1,
@@ -46,7 +46,7 @@ enum EthosU55SHRamElements
     SHRAM_Last = SHRAM_Acc40
 };
 
-enum class EthosUTraversal
+enum class EthosUTraversal : uint8_t
 {
     DepthFirst = 0,
     PartKernel = 1,
@@ -91,10 +91,11 @@ private:
     SHRAMLayout _layout;
     Shape _ifmBlock;
     Shape _ofmBlock;
+    int _bankSize = 0;
     EthosU55SHRamElements _accumulatorType = SHRAM_Acc32;
     EthosUTraversal _traversal = EthosUTraversal::DepthFirst;
-    int _bankSize = 0;
-    int _ifmDepthBufScaling = 0;
+    int8_t _ifmDepthBufScaling = 0;
+    std::unique_ptr<EthosU55OpConfig> _prevConfig;
 
 public:
     EthosUTraversal Traversal() const { return _traversal; }
@@ -107,6 +108,9 @@ public:
     Point2i OptimalStripeGranule() override;
     int OptimalDepthGranule() override;
     std::string ToString(bool full) override;
+
+    void AttachPrevConfig(std::unique_ptr<EthosU55OpConfig> prev);
+    EthosU55OpConfig *PrevConfig();
 };
 
 ///
@@ -216,7 +220,7 @@ protected:
     Shape OfmUBlock() { return _ofmUBlock; }
 
     void ApplyConfig(const AcceleratorConfig *cfg);
-    std::unique_ptr<ArchitectureOpConfig> FindBlockConfig(OpType opType, const ArchitectureConfigQuery &query);
+    std::unique_ptr<EthosU55OpConfig> FindBlockConfig(OpType opType, const ArchitectureConfigQuery &query);
     bool TryBlockConfig(EthosU55OpConfig::SHRAMLayout &layout, int ewUsage, const Shape &ofmBlock, const Shape &ifmBlock,
         int ifmBits, int ifmGranule, int accBits, int accGranule, int lutBanks, int ifmDepthBufScaling);
diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp b/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp
index 31649b43..7f2101c9 100644
--- a/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp
+++ b/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp
@@ -199,7 +199,12 @@ Flags<QueryResult> EthosU55Constraints::OperatorQuery(OpType opType, const ArchO
         if ( query->transposeMask == TransposeType::NWHC || query->transposeMask == TransposeType::NHCW ||
              query->transposeMask == TransposeType::NCWH )
         {
-            if ( req ) req->ofmFormat = TensorFormat::NHWC;
+            if ( req )
+            {
+                req->req.Set(ArchRequirement::OutputFormat, ArchRequirement::InputFormat);
+                req->ifmFormat = TensorFormat::NHWC;
+                req->ofmFormat = TensorFormat::NHWC;
+            }
             return QueryResult::NativeConstrainedHasReq;
         }
     }
@@ -210,12 +215,18 @@
     {
         if ( req )
         {
-            req->req = ArchRequirement::ScratchTensor;
-            req->scratch.size = query->ofm.shape;
-            req->scratch.type = DataType::Int32;
-            req->scratch.format = TensorFormat::NHWC;
+            req->req.Set(ArchRequirement::ScratchTensor, ArchRequirement::OutputFormat, ArchRequirement::InputFormat);
+            if ( query->ifm[0].shape )
+            {
+                req->scratch.size = query->ifm[0].shape.WithDepth(query->ifm[0].shape.Depth() + 1);
+                req->scratch.type = DataType::Int32;
+                req->scratch.format = TensorFormat::NHWC;
+            }
+            req->ifmFormat = TensorFormat::Unknown;
+            req->ifm1Format = TensorFormat::NHWC;  // IFM1 and OFM are depth-sliced
+            req->ofmFormat = TensorFormat::NHWC;   // and cannot be addressed if B16
         }
-        return QueryResult::Unsupported;
+        return QueryResult::NativeHasReq;
     }
     else if ( (opType == OpType::Sigmoid) || (opType == OpType::Tanh) )
     {
diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55_performance.cpp b/ethosu/regor/architecture/ethosu55/ethos_u55_performance.cpp
index 9032c504..82aa1ddd 100644
--- a/ethosu/regor/architecture/ethosu55/ethos_u55_performance.cpp
+++ b/ethosu/regor/architecture/ethosu55/ethos_u55_performance.cpp
@@ -560,8 +560,6 @@ ElementAccess EthosU55Performance::MeasureElementAccess(const PerformanceQuery &
     {
         // IFM0 is read multiple times to cover all elements in ofmShape
         access.ifmRead[0] = Shape::RoundAway(query.ofmShape, ofmRounding).Elements();
-        // Complete OFM is written
-        access.ofmWrite = access.ifmRead[0];
     }
     else if ( query.type == OpType::Transpose )
     {
@@ -569,8 +567,11 @@ ElementAccess EthosU55Performance::MeasureElementAccess(const PerformanceQuery &
     }
     else if ( query.type == OpType::MatMul )
     {
-        access.ifmRead[0] = query.ifmShape[0].Elements();
+        // Requires pretransposed operand
+        int cols = query.ifmShape[1].Width();
+        access.ifmRead[0] = query.ifmShape[0].Elements() * cols;
        access.ifmRead[1] = query.ifmShape[1].Elements();
+        access.tmpRead = access.tmpWrite = access.ifmRead[0];
     }
     else
     {
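For reference, the command sequence generated in the next file computes the following, with IFM1 pre-transposed so that its rows hold the columns of the right-hand operand. A scalar sketch of the Mul + ReduceSum decomposition (illustrative C++, not regor code; quantization rescaling omitted; requires <cstdint>):

    // ifm0[b][r][k], ifm1[c][k] (pre-transposed), ofm[b][r][c], all row-major
    void MatMulReference(const int8_t *ifm0, const int8_t *ifm1, int32_t *ofm,
                         int batches, int rows, int cols, int outCols)
    {
        for ( int b = 0; b < batches; b++ )
            for ( int c = 0; c < outCols; c++ )      // one Mul + ReduceSum pair per step
                for ( int r = 0; r < rows; r++ )
                {
                    int32_t acc = 0;                 // the Int32 scratch tensor holds the products
                    for ( int k = 0; k < cols; k++ )
                        acc += int32_t(ifm0[(b * rows + r) * cols + k]) * int32_t(ifm1[c * cols + k]);
                    ofm[(b * rows + r) * outCols + c] = acc;  // ReduceSum over depth: one OFM column element
                }
    }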
diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.cpp b/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.cpp
index adb52f35..736b129e 100644
--- a/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.cpp
+++ b/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.cpp
@@ -942,7 +942,7 @@ void EthosU55RCSGenerator::GenerateOFMPrecision(const HLCFeatureMap &fm, bool us
 }
 
 // Generates common IFM registers
-void EthosU55RCSGenerator::GenerateIFM(OpType opType, const HLCFeatureMap &fm, const Box &inputArea)
+void EthosU55RCSGenerator::GenerateIFM(const HLCFeatureMap &fm, const Box &inputArea)
 {
     CheckAddresses(fm);
     Emit(isa::npu_set_ifm_region_t(ToRegion(fm.memArea)));
@@ -970,7 +970,7 @@ void EthosU55RCSGenerator::GenerateIFM(OpType opType, const HLCFeatureMap &fm, c
 }
 
 // Generates common IFM2 registers
-void EthosU55RCSGenerator::GenerateIFM2(OpType opType, const HLCFeatureMap &fm, const Box &inputArea, bool isScalar, int32_t scalarValue)
+void EthosU55RCSGenerator::GenerateIFM2(const HLCFeatureMap &fm, const Box &inputArea, bool isScalar, int32_t scalarValue)
 {
     if ( isScalar )
     {
@@ -1003,7 +1003,7 @@ void EthosU55RCSGenerator::GenerateIFM2(OpType opType, const HLCFeatureMap &fm, 
 }
 
 // Generates OFM registers
-void EthosU55RCSGenerator::GenerateOFM(OpType opType, const HLCFeatureMap &fm, const Box &outputArea)
+void EthosU55RCSGenerator::GenerateOFM(const HLCFeatureMap &fm, const Box &outputArea)
 {
     CheckAddresses(fm);
     Emit(isa::npu_set_ofm_region_t(ToRegion(fm.memArea)));
@@ -1483,6 +1483,135 @@ void EthosU55RCSGenerator::InsertTransposeCommand(const HLCStripe *stripe, Tempo
     }
 }
 
+namespace MatMul
+{
+inline int Cols(const Shape &shape)
+{
+    return shape.Depth();
+}
+inline int Rows(const Shape &shape)
+{
+    return shape.Width();
+}
+inline int Batch(const Shape &shape)
+{
+    return shape.Height();
+}
+}  // namespace MatMul
+
+void EthosU55RCSGenerator::InsertMatMulCommand(const HLCStripe *stripe, Temporaries &temps, std::vector<HighLevelCommand *> &emitted)
+{
+    auto op = stripe->operation.get();
+    assert(op && op->ifm.size() > 2);
+    // Expect 3 inputs: 2 IFMs and one scratch tensor
+    if ( op->ifm.size() < 3 )
+    {
+        return;
+    }
+
+    HLCFeatureMap inFM0 = op->ifm[0];
+    HLCFeatureMap inFM1 = op->ifm[1];
+    HLCFeatureMap tempFM = op->ifm[2];
+    HLCFeatureMap outFM = op->ofm;
+
+    assert(op->subOps.empty());
+    assert(tempFM.dataType == DataType::Int32);
+    assert(inFM1.format == TensorFormat::NHWC);
+    assert(outFM.format == TensorFormat::NHWC);
+    assert(tempFM.format == TensorFormat::NHWC);
+
+    // Ensure shapes are in the form: 1, Batch=Height, Rows=Width, Cols=Depth
+    inFM0.shape = Shape::PadAxes(inFM0.shape, 4, 1);
+    inFM1.shape = Shape::PadAxes(inFM1.shape, 4, 1);
+    outFM.shape = Shape::PadAxes(outFM.shape, 4, 1);
+
+    assert((!inFM0.slice.shape || (inFM0.shape.WC() == inFM0.slice.shape.WC())) && "Implementation cannot be sliced in depth");
+    assert(MatMul::Cols(inFM0.shape) == MatMul::Cols(inFM1.shape) && (MatMul::Rows(inFM1.shape) == MatMul::Cols(outFM.shape)) && "Second ifm must be pre-transposed");
+
+    // Minimum required temporary space is one IFM0 W/C slice.
+    // Batches can be executed en-masse or as smaller H-slices (deduce from size of temporary space)
+    bool lowMemoryMode = (tempFM.shape.Elements() < inFM0.shape.Elements()) && (tempFM.shape.Elements() >= inFM0.shape.ElementsWC());
+    assert(lowMemoryMode || (tempFM.shape.Elements() >= inFM0.shape.Elements()));
+
+    // Execute batches individually if required
+    int maxSteps = MatMul::Cols(outFM.shape);
+    int batchLoops = 1;
+    int batches = MatMul::Batch(outFM.shape);
+    if ( lowMemoryMode )
+    {
+        std::swap(batches, batchLoops);
+    }
+
+    // Broadcast H/C slices
+    Shape ifm0Shape = Shape(1, batches, MatMul::Rows(inFM0.shape), MatMul::Cols(inFM0.shape));
+    Shape ifm1Shape = Shape(1, batches, 1, MatMul::Cols(inFM1.shape));
+
+    // Temporary unquantized tensor for MUL result
+    tempFM.slice.offset = ifm0Shape.WithZeros();
+    tempFM.slice.shape = ifm0Shape;
+    tempFM.strides = Shape::GetStridesForShape(tempFM.shape, Shape(sizeof(uint32_t)));
+    tempFM.quantization = Quantization::Unit();
+    assert(tempFM.usage == TensorUsage::Scratch);
+
+    // Final output tensor slice sizes
+    Shape ofmShape = Shape(1, batches, MatMul::Rows(ifm0Shape), 1);
+
+    EthosU55OpConfig *reduceConfig = static_cast<EthosU55OpConfig *>(stripe->operation->config);
+    assert(reduceConfig);
+    EthosU55OpConfig *mulConfig = reduceConfig->PrevConfig();
+    assert(mulConfig);
+
+    // Push quantisation on to the last operation
+    QuantizedScale qs0 = inFM0.quantization.scales.empty() ? QuantizedScale::Unit() : inFM0.quantization.scales[0];
+    QuantizedScale qs1 = inFM1.quantization.scales.empty() ? QuantizedScale::Unit() : inFM1.quantization.scales[0];
+    QuantizedScale qOfm = outFM.quantization.scales.empty() ? QuantizedScale::Unit() : outFM.quantization.scales[0];
+    inFM0.quantization.scales.clear();
+    inFM1.quantization.scales.clear();
+    outFM.quantization.scales.clear();
+
+    double scaling = (qs0.Dequantize() * qs1.Dequantize()) / qOfm.Dequantize();
+    outFM.quantization.type = QuantizationType::EXPLICIT;
+    outFM.quantization.scales.push_back(QuantizedScale(scaling));
+
+    for ( int batch = 0; batch < batchLoops; batch++ )
+    {
+        Shape ifm1Start(0, batch, 0, 0);
+        Shape ofmStart(0, batch, 0, 0);
+        for ( int step = 0; step < maxSteps; step++ )
+        {
+            // Step 1: MUL: IFM0 x IFM1 -> TEMP BUFFER
+            // Create Multiply stripe operation
+            auto mul = std::make_unique<HLCStripe>(std::make_shared<HLCOperation>());
+            mul->operation->type = OpType::Mul;
+            mul->operation->kernel = Kernel::UnitKernel();
+            mul->operation->ifm.push_back(inFM0);
+            mul->operation->ifm.push_back(inFM1);
+            mul->operation->ofm = tempFM;
+            mul->operation->config = mulConfig;
+            mul->ofmArea = ifm0Shape;
+            mul->ifmAreas.emplace_back(ifm0Shape);
+            mul->ifmAreas.emplace_back(ifm1Start, Box::Size(ifm1Shape));
+            mul->opGroup = nullptr;
+            emitted.push_back(mul.get());
+            temps.cmds.push_back(std::move(mul));
+
+            // Step 2: REDUCE SUM: TEMP BUFFER -> OFM
+            // Create ReduceSum stripe operation
+            auto sum = std::make_unique<HLCStripe>(std::make_shared<HLCOperation>());
+            sum->operation->type = OpType::ReduceSum;
+            sum->operation->kernel = Kernel::UnitKernel();
+            sum->operation->ifm.push_back(tempFM);
+            sum->operation->ofm = outFM;
+            sum->operation->config = reduceConfig;
+            sum->ofmArea = Box(ofmStart, Box::Size(ofmShape));
+            sum->ifmAreas.emplace_back(tempFM.slice.shape);
+            sum->opGroup = nullptr;
+            emitted.push_back(sum.get());
+            temps.cmds.push_back(std::move(sum));
+
+            // Move to next input offset and output slice
+            ifm1Start[-2] += 1;
+            ofmStart[-1] += 1;
+        }
+    }
+}
 
 //----------------------------------------------------------------------
 // Operations
@@ -1544,7 +1673,7 @@ void EthosU55RCSGenerator::GenerateCommon(const HLCStripe *stripe, bool useGloba
     MemoryAccesses &memoryAccesses, int ifm0Index)
 {
     auto op = stripe->operation.get();
-    GenerateIFM(op->type, op->ifm[ifm0Index], stripe->ifmAreas[ifm0Index]);
+    GenerateIFM(op->ifm[ifm0Index], stripe->ifmAreas[ifm0Index]);
     memoryAccesses.push_back(ToMemoryAccess(op->ifm[ifm0Index], stripe->ifmAreas[ifm0Index], AccessDirection::Read));
 
     // Select rounding based on RCSIfmScaleMode
@@ -1563,7 +1692,7 @@ void EthosU55RCSGenerator::GenerateCommon(const HLCStripe *stripe, bool useGloba
     {
         GeneratePadding(stripe->padding);
     }
-    GenerateOFM(op->type, op->ofm, stripe->ofmArea);
+    GenerateOFM(op->ofm, stripe->ofmArea);
     memoryAccesses.push_back(ToMemoryAccess(op->ofm, stripe->ofmArea, AccessDirection::Write));
     GenerateOFMPrecision(op->ofm, useGlobalScale);
     EthosU55OpConfig *config = static_cast<EthosU55OpConfig *>(stripe->operation->config);
@@ -1667,7 +1796,7 @@ void EthosU55RCSGenerator::GenerateElementwiseOp(const HLCStripe *stripe, Memory
     assert(size_t(ifm2Index) < stripe->ifmAreas.size());
     const HLCFeatureMap &ifm2 = op->ifm.at(ifm2Index);
     bool isScalar = IsScalar(ifm2, scalarValue);
-    GenerateIFM2(opType, ifm2, stripe->ifmAreas[ifm2Index], isScalar, scalarValue);
+    GenerateIFM2(ifm2, stripe->ifmAreas[ifm2Index], isScalar, scalarValue);
     if ( !isScalar )
     {
         memoryAccesses.push_back(ToMemoryAccess(ifm2, stripe->ifmAreas[ifm2Index], AccessDirection::Read));
@@ -1768,7 +1897,8 @@ void EthosU55RCSGenerator::PrepareCommand(int index, HighLevelCommand *cmd, Temp
     }
     else if ( op->type == OpType::MatMul )
     {
-        return;  // Delete until implemented
+        InsertMatMulCommand(stripe, temps, emitted);
+        return;
     }
     else if ( _arch->_shram.reservedEndBanks == 0 )
     {
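The rescale pushed onto the final ReduceSum above follows from dequantizing both operands and requantizing the output: real = (s0*q0)*(s1*q1), so q_ofm = ((s0*s1)/s_ofm)*(q0*q1). A tiny worked check (hypothetical scale values):

    double s0 = 0.02, s1 = 0.05, sOfm = 0.1;
    double scaling = (s0 * s1) / sOfm;  // 0.01: each int32 product sum is scaled down 100x into the OFM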
diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.hpp b/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.hpp
index e6a33820..233b90ba 100644
--- a/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.hpp
+++ b/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.hpp
@@ -210,11 +210,11 @@ protected:
     // Generates OFM_PRECISION register
     void GenerateOFMPrecision(const HLCFeatureMap &fm, bool useGlobalScale);
     // Generates common IFM registers
-    void GenerateIFM(OpType opType, const HLCFeatureMap &fm, const Box &inputArea);
+    void GenerateIFM(const HLCFeatureMap &fm, const Box &inputArea);
     // Generates common IFM2 registers
-    void GenerateIFM2(OpType opType, const HLCFeatureMap &fm, const Box &inputArea, bool isScalar, int32_t scalarValue);
+    void GenerateIFM2(const HLCFeatureMap &fm, const Box &inputArea, bool isScalar, int32_t scalarValue);
     // Generates OFM registers
-    void GenerateOFM(OpType opType, const HLCFeatureMap &fm, const Box &outputArea);
+    void GenerateOFM(const HLCFeatureMap &fm, const Box &outputArea);
     // Generates WEIGHT registers
     void GenerateWeights(const HLCStripe *stripe, MemoryAccesses &memoryAccesses);
     // Generates SCALE registers
@@ -240,6 +240,8 @@ protected:
     virtual void InsertTileDMACommand(const HLCStripe *stripe, Temporaries &temps, std::vector<HighLevelCommand *> &emitted);
     // Inserts commands to handle transposing
     virtual void InsertTransposeCommand(const HLCStripe *stripe, Temporaries &temps, std::vector<HighLevelCommand *> &emitted);
+    // Inserts commands to handle MATMUL operations
+    void InsertMatMulCommand(const HLCStripe *stripe, Temporaries &temps, std::vector<HighLevelCommand *> &emitted);
 
     //----------------------------------------------------------------------
     // Operations
diff --git a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp
index 59ec0938..338a65b3 100644
--- a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp
+++ b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp
@@ -209,6 +209,13 @@ Flags<QueryResult> EthosU85Constraints::OperatorQuery(OpType opType, const ArchO
         }
         return QueryResult::NativeHasReq;
     }
+    else if ( opType == OpType::MatMul )
+    {
+        if ( (query->ofm.shape.Size() >= 2) && query->ofm.shape.Elements() > query->ofm.shape.ElementsWC() )
+        {
+            return QueryResult::NativeDecompose;
+        }
+    }
 
     return QueryResult::Native;
 }
diff --git a/ethosu/regor/common/box.hpp b/ethosu/regor/common/box.hpp
index 2e6eb729..84377977 100644
--- a/ethosu/regor/common/box.hpp
+++ b/ethosu/regor/common/box.hpp
@@ -30,6 +30,13 @@ private:
     Shape _start;
     Shape _end;
 
+public:
+    struct Size
+    {
+        const Shape &_size;
+        Size(const Shape &size) : _size(size) {}
+    };
+
 public:
     Box() = default;
 
@@ -39,6 +46,8 @@ public:
         assert(start <= end);
     }
 
+    Box(const Shape &start, const Box::Size &size) : _start(start), _end(start + size._size) {}
+
     Box(const Shape &end) : Box(end.WithZeros(), end) {}
 
     Shape &Start() { return _start; }
diff --git a/ethosu/regor/common/shape.hpp b/ethosu/regor/common/shape.hpp
index 71a248f8..0c08229c 100644
--- a/ethosu/regor/common/shape.hpp
+++ b/ethosu/regor/common/shape.hpp
@@ -465,6 +465,12 @@ public:
         return Point2<TYPE>(TYPE(Width()), TYPE(Depth()));
     }
 
+    template<typename TYPE>
+    Point2<TYPE> WC(TYPE pad) const
+    {
+        return Point2<TYPE>((_last > 0) ? TYPE(At(1)) : pad, (_last < 0) ? pad : TYPE(At(0)));
+    }
+
     template<typename TYPE>
     Point2<TYPE> WH() const
     {
@@ -486,6 +492,13 @@ public:
         return int(result);
     }
 
+    int ElementsWC() const
+    {
+        int64_t result = int64_t(Width()) * Depth();
+        assert(result <= std::numeric_limits<int>::max());
+        return int(result);
+    }
+
     int Elements() const
     {
         int64_t result = Elements64();
@@ -768,7 +781,7 @@ public:
     static Shape PadAxes(const Shape &shape, int axes, int padValue)
     {
-        if ( shape.Size() == axes ) return shape;
+        if ( shape.Size() >= axes ) return shape;
         return Shape(shape, std::max(axes, shape.Size()), padValue);
     }
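Box::Size above is a tag type: it makes a start+extent construction unambiguous alongside the existing start+end constructor, which takes the same argument types. A usage sketch (hypothetical values):

    Shape start(0, 1, 0, 0);
    Shape extent(1, 1, 8, 16);
    Box a(start, start + extent);      // start/end form
    Box b(start, Box::Size(extent));   // start/extent form: same box, intent explicit at the call site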
diff --git a/ethosu/regor/compiler/high_level_command_stream_generator.cpp b/ethosu/regor/compiler/high_level_command_stream_generator.cpp
index d5c5a2ad..21d81c0a 100644
--- a/ethosu/regor/compiler/high_level_command_stream_generator.cpp
+++ b/ethosu/regor/compiler/high_level_command_stream_generator.cpp
@@ -254,7 +254,7 @@ static void MakeFeatureMap(TensorUsage usage, const SchedulerConnection *schedCo
     fm.reverse = schedConn->reverse;
     fm.resamplingMode = schedConn->resamplingMode;
     fm.rounding = HLCRoundMode(schedConn->rounding);
-    fm.uid = schedConn->tensor->uid;
+    fm.uid = schedTens->uid;
 }
 
 static std::unique_ptr<HLCWeights> MakeWeights(NpuWeightTensor *srcTensor, Buffering buffering, SchedulerTensor *bufTensor = nullptr)
@@ -284,18 +284,30 @@ static HLCSubOperation MakeSubOperation(const std::unique_ptr<SchedulerOperatio
     hlcSubOp.type = schedOp->Type();
     auto lutConn = schedOp->TryInput(TensorUsage::LUT);
-
+    size_t ifms = 0;
     for ( const auto &input : schedOp->inputs.pairs() )
     {
-        std::vector<HLCFeatureMap>::iterator at;
-        if ( IsIFM(input.first) )
+        if ( IsIFM(input.first) || GetUsageType(input.first) == TensorUsage::Scratch )
         {
-            at = hlcSubOp.ifm.emplace(std::upper_bound(hlcSubOp.ifm.begin(), hlcSubOp.ifm.end(), input.first,
-                [](TensorUsage usage, const HLCFeatureMap &fm) { return usage < fm.usage; }));
+            std::vector<HLCFeatureMap>::iterator at;
+            if ( IsIFM(input.first) )
+            {
+                // Insert IFMs, into the IFM section [0..ifms) sorted into order.
+                at = hlcSubOp.ifm.emplace(std::upper_bound(hlcSubOp.ifm.begin(),
+                    hlcSubOp.ifm.begin() + std::min(ifms, hlcSubOp.ifm.size()), input.first,
+                    [](TensorUsage usage, const HLCFeatureMap &fm) { return usage < fm.usage; }));
+                ifms++;  // Increase size of IFM section
+            }
+            else
+            {
+                // Non-IFM tensors get appended
+                at = hlcSubOp.ifm.emplace(hlcSubOp.ifm.end());
+            }
             MakeFeatureMap(input.first, &input.second, *at);
         }
     }
     MakeFeatureMap(TensorUsage::OFM, schedOp->OFM(), hlcSubOp.ofm);
+
     hlcSubOp._srcId = schedOp->Uid();
 
     if ( schedOp->Type() == OpType::LeakyRelu )
@@ -323,18 +335,29 @@ static std::shared_ptr<HLCOperation> MakeOperation(SchedulerOperation *schedOp,
     op->kernel = *schedOp->Kernel();
     op->config = opInfo->Config();
     op->_srcId = schedOp->Uid();
-
+    size_t ifms = 0;
     for ( const auto &input : schedOp->inputs.pairs() )
     {
-        std::vector<HLCFeatureMap>::iterator at;
-        if ( IsIFM(input.first) )
+        if ( IsIFM(input.first) || GetUsageType(input.first) == TensorUsage::Scratch )
         {
-            at = op->ifm.emplace(std::upper_bound(op->ifm.begin(), op->ifm.end(), input.first,
-                [](TensorUsage usage, const HLCFeatureMap &fm) { return usage < fm.usage; }));
+            std::vector<HLCFeatureMap>::iterator at;
+            if ( IsIFM(input.first) )
+            {
+                // Insert IFMs, into the IFM section [0..ifms) sorted into order.
+                at = op->ifm.emplace(std::upper_bound(op->ifm.begin(), op->ifm.begin() + std::min(ifms, op->ifm.size()),
+                    input.first, [](TensorUsage usage, const HLCFeatureMap &fm) { return usage < fm.usage; }));
+                ifms++;  // Increase size of IFM section
+            }
+            else
+            {
+                // Non-IFM tensors get appended
+                at = op->ifm.emplace(op->ifm.end());
+            }
             MakeFeatureMap(input.first, &input.second, *at);
         }
     }
     MakeFeatureMap(TensorUsage::OFM, schedOp->OFM(), op->ofm);
+
 #ifndef NDEBUG
     op->name = schedOp->OFM()->tensor->Name();
 #endif
diff --git a/ethosu/regor/compiler/scheduler.cpp b/ethosu/regor/compiler/scheduler.cpp
index 6c74bf3f..06d007a9 100644
--- a/ethosu/regor/compiler/scheduler.cpp
+++ b/ethosu/regor/compiler/scheduler.cpp
@@ -28,6 +28,7 @@
 #include "common/vector_span.hpp"
 #include "faststorage_allocator.hpp"
 #include "live_range.hpp"
+#include "scheduler_decompose.hpp"
 #include "tensor_allocator.hpp"
 
 #include
@@ -230,67 +231,90 @@ int Scheduler::UpdateSchedulerTensor(TensorUsage usage, SchedulerConnection *con
 
     for ( auto producer : tensor->producers )
     {
+        if ( producer->IsNpuOp() )
+        {
+            tensor->hasNPUWriters = true;
+        }
+        else
+        {
+            tensor->hasCPUWriters = true;
+        }
+
         // TODO: Gather doesn't support brick format yet (MLBEDSW-8410)
         if ( producer->Type() == OpType::Scatter || producer->Type() == OpType::Gather )
         {
             tensor->needsLinearFormat = true;
+            continue;
         }
         // TODO: Tile doesn't support brick format yet (MLBEDSW-9485)
         else if ( producer->Type() == OpType::Tile )
         {
             tensor->needsLinearFormat = true;
+            continue;
         }
-        else if ( producer->Type() == OpType::Transpose )
+        else
         {
             ArchRequirements req;
             ArchOperatorQuery query;
             query.transposeMask = producer->OFM()->transpose;
-            if ( _arch->Constraints()->OperatorQuery(OpType::Transpose, &query, &req).Any(QueryResult::Native) )
+            if ( _arch->Constraints()->OperatorQuery(producer->Type(), &query, &req).Any(QueryResult::Native) )
             {
-                if ( req.ofmFormat == TensorFormat::NHWC )
+                if ( (req.req % ArchRequirement::OutputFormat) && req.ofmFormat == TensorFormat::NHWC )
                 {
                     tensor->needsLinearFormat = true;
+                    continue;
                 }
             }
        }
+    }
 
-        if ( producer->IsNpuOp() )
+    for ( auto consumer : tensor->consumers )
+    {
+        if ( consumer->IsNpuOp() )
         {
-            tensor->hasNPUWriters = true;
+            tensor->hasNPUReaders = true;
         }
         else
         {
-            tensor->hasCPUWriters = true;
+            tensor->hasCPUReaders = true;
         }
-    }
 
-    for ( auto consumer : tensor->consumers )
-    {
         // TODO: Gather doesn't support brick format yet (MLBEDSW-8410)
         if ( consumer->Type() == OpType::Scatter || consumer->Type() == OpType::Gather )
         {
             tensor->needsLinearFormat = true;
+            continue;
         }
         // TODO: Tile doesn't support brick format yet (MLBEDSW-9485)
         else if ( consumer->Type() == OpType::Tile )
        {
             tensor->needsLinearFormat = true;
+            continue;
         }
         // Int32 ReduceSum requires linear format
         else if ( consumer->Type() == OpType::ReduceSum && tensor->dataType == DataType::Int32 )
         {
             tensor->needsLinearFormat = true;
+            continue;
         }
-        else if ( consumer->Type() == OpType::Transpose )
+
+        TensorUsage usedAs = TensorUsage::None;
+        auto tensorConn = consumer->HasInput(tensor, usedAs);
+
+        ArchRequirements req;
+        ArchOperatorQuery query;
+        Set(query.ifm[0], consumer->TryIFM(0));
+        Set(query.ifm[1], consumer->TryIFM(1));
+        query.transposeMask = consumer->OFM()->transpose;
+        if ( _arch->Constraints()->OperatorQuery(consumer->Type(), &query, &req).Any(QueryResult::Native) )
         {
-            ArchRequirements req;
-            ArchOperatorQuery query;
-            query.transposeMask = consumer->OFM()->transpose;
-            if ( _arch->Constraints()->OperatorQuery(OpType::Transpose, &query,
&req).Any(QueryResult::Native) )
+            if ( (req.req % ArchRequirement::InputFormat) )
             {
-                if ( req.ofmFormat == TensorFormat::NHWC )
+                if ( (usedAs == TensorUsage::IFM0 && req.ifmFormat == TensorFormat::NHWC) ||
+                     (usedAs == TensorUsage::IFM1 && req.ifm1Format == TensorFormat::NHWC) )
                 {
                     tensor->needsLinearFormat = true;
+                    continue;
                 }
             }
         }
@@ -298,27 +322,10 @@ int Scheduler::UpdateSchedulerTensor(TensorUsage usage, SchedulerConnection *con
         // Check if consumer shape requires linear format
         // Brick format can only be used if both shapes have equal W and C
         // Need to check full shape on connection since tensor might have many producers (concat)
-        auto ifm0 = consumer->TryIFM(0);
-        auto ifm1 = consumer->TryIFM(1);
-        auto ifm2 = consumer->TryIFM(2);
-        if ( (ifm0 && ifm0->tensor.get() == tensor && ifm0->SliceShape() && conn->shape &&
-                 Shape::PadAxes(ifm0->SliceShape(), 2, 1).WC() != Shape::PadAxes(conn->shape, 2, 1).WC()) ||
-             (ifm1 && ifm1->tensor.get() == tensor && ifm1->SliceShape() && conn->shape &&
-                 Shape::PadAxes(ifm1->SliceShape(), 2, 1).WC() != Shape::PadAxes(conn->shape, 2, 1).WC()) ||
-             (ifm2 && ifm2->tensor.get() == tensor && ifm2->SliceShape() && conn->shape &&
-                 Shape::PadAxes(ifm2->SliceShape(), 2, 1).WC() != Shape::PadAxes(conn->shape, 2, 1).WC()) )
+        if ( IsIFM(usedAs) && conn->shape && tensorConn && (tensorConn->SliceShape().WC(1) != conn->shape.WC(1)) )
         {
             tensor->needsLinearFormat = true;
         }
-
-        if ( consumer->IsNpuOp() )
-        {
-            tensor->hasNPUReaders = true;
-        }
-        else
-        {
-            tensor->hasCPUReaders = true;
-        }
     }
 
     // Initial criteria (may change)
diff --git a/ethosu/regor/compiler/scheduler_decompose.cpp b/ethosu/regor/compiler/scheduler_decompose.cpp
index 079509f3..81234b4a 100644
--- a/ethosu/regor/compiler/scheduler_decompose.cpp
+++ b/ethosu/regor/compiler/scheduler_decompose.cpp
@@ -204,16 +204,13 @@ bool CanRunOnHardware(Architecture *arch, const SchedulerOperation *schedOp)
     }
     if ( schedOp->Type() == OpType::MatMul )
     {
-        auto &ofmShape = schedOp->OFM()->SliceShape();
-        if ( ofmShape.Size() > 2 && ofmShape.Elements() > ofmShape.Width() * ofmShape.Depth() ) return false;
-
         const auto ofmConn = schedOp->OFM();
         ArchOperatorQuery query;
         Set(query.ifm[0], schedOp->IFM(0));
         Set(query.ifm[1], schedOp->IFM(1));
         Set(query.ofm, ofmConn);
         query.transposeMask = ofmConn->transpose;
-        if ( !arch->Constraints()->OperatorQuery(OpType::MatMul, &query, nullptr).Any(QueryResult::Native) )
+        if ( (arch->Constraints()->OperatorQuery(OpType::MatMul, &query, nullptr) & QueryResult::NativeDecompose) != QueryResult::Native )
         {
             return false;
         }
diff --git a/ethosu/regor/compiler/scheduler_operation.hpp b/ethosu/regor/compiler/scheduler_operation.hpp
index f89f38b1..4cb6456b 100644
--- a/ethosu/regor/compiler/scheduler_operation.hpp
+++ b/ethosu/regor/compiler/scheduler_operation.hpp
@@ -244,6 +244,19 @@ public:
     SchedulerConnection *IFM(int index) { return &inputs.at(MakeTensorUsage(TensorUsage::IFM, index)); }
     const SchedulerConnection *IFM(int index) const { return &inputs.at(MakeTensorUsage(TensorUsage::IFM, index)); }
 
+    SchedulerConnection *HasInput(const SchedulerTensor *tensor, TensorUsage &as)
+    {
+        for ( auto &pair : inputs.pairs() )
+        {
+            if ( pair.second.tensor.get() == tensor )
+            {
+                as = pair.first;
+                return &pair.second;
+            }
+        }
+        return nullptr;
+    }
+
     // Invalidates all pointers to input connections.
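The masked comparison in CanRunOnHardware above keeps only the Native and Decompose bits and then requires the result to be exactly Native, i.e. the op must be natively supported and must not request decomposition. Spelled out (sketch):

    auto result = arch->Constraints()->OperatorQuery(OpType::MatMul, &query, nullptr);
    // (result & NativeDecompose) == Native          -> runs on hardware as-is
    // (result & NativeDecompose) == NativeDecompose -> native only after decomposition
    // Native bit clear                              -> cannot run natively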
    void RemoveInput(TensorUsage usage)
     {
diff --git a/ethosu/regor/compiler/scheduler_packing.cpp b/ethosu/regor/compiler/scheduler_packing.cpp
index f94e330f..ce79d63f 100644
--- a/ethosu/regor/compiler/scheduler_packing.cpp
+++ b/ethosu/regor/compiler/scheduler_packing.cpp
@@ -548,6 +548,7 @@ std::unique_ptr<SchedulerOperation> SchedulerPacking::MakeSchedulerOperation(Ope
         auto scratchTensor = std::make_shared<SchedulerTensor>(req.scratch.type, req.scratch.size, req.scratch.format);
         SchedulerConnection *scratchConn = schedOp->AddInput(TensorUsage::Scratch0, scratchTensor);
         scratchConn->shape = req.scratch.size;
+        scratchTensor->memArea = _arch->StagingMemory();
     }
 }
 
diff --git a/ethosu/regor/tflite/tflite_supported_operators_u55.cpp b/ethosu/regor/tflite/tflite_supported_operators_u55.cpp
index cfedfb3b..837fdfd1 100644
--- a/ethosu/regor/tflite/tflite_supported_operators_u55.cpp
+++ b/ethosu/regor/tflite/tflite_supported_operators_u55.cpp
@@ -33,6 +33,7 @@ TfLiteSupportedOperatorsU55::TfLiteSupportedOperatorsU55(IArchitectureConstraint
         // clang-format off
         OpType::Add,
         OpType::AvgPool,
+        OpType::BatchMatMul,
         OpType::Concat,
         OpType::Conv2D,
         OpType::DepthwiseConv2D,
-- 
GitLab