diff --git a/ethosu/regor/architecture/architecture_constraints.hpp b/ethosu/regor/architecture/architecture_constraints.hpp index 70931bcf8ae66995df94312dd8089489adc650f3..72c20e6682804d0d0c26a639724f91b61d9cffeb 100644 --- a/ethosu/regor/architecture/architecture_constraints.hpp +++ b/ethosu/regor/architecture/architecture_constraints.hpp @@ -79,10 +79,13 @@ struct ExecutionQuery bool quantScalingInvalidOrUnequal = false; }; -namespace Constraints +enum class TransposeSupport { - -} // namespace Constraints + None, + NHWC = 1, + NHCWB16 = 2, + Any = NHWC | NHCWB16, +}; /// /// Architecture capabilties query @@ -96,7 +99,7 @@ public: virtual bool SupportsFusedRescale(OpType opType, TensorUsage tensorUsage, DataType fromType, DataType toType, const Quantization &quantization) = 0; virtual bool SupportsRescale(DataType fromType, DataType toType) = 0; - virtual bool SupportsTranspose(OpType opType, TransposeType transposeType) = 0; + virtual TransposeSupport SupportsTranspose(OpType opType, TransposeType transposeType) = 0; virtual bool SupportsAccumulatorSaveRestore() = 0; bool CanExecute(const ExecutionQuery &query) diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55.cpp b/ethosu/regor/architecture/ethosu55/ethos_u55.cpp index 4abfa22f2c595b4a06d43370430e87cd6e25f786..15a7ba8ccf9668173d0eb2916d6f7be889fdda7a 100644 --- a/ethosu/regor/architecture/ethosu55/ethos_u55.cpp +++ b/ethosu/regor/architecture/ethosu55/ethos_u55.cpp @@ -102,7 +102,6 @@ const static Shape MAX_SHAPE(nullptr, 8, 65536); ArchEthosU55::ArchEthosU55() : _subkernelMax(8, 8, 65536), _ofmBlockMax(32, 64, 128) { _weightEncoder = std::make_unique(this); - _rcsGenerator = std::make_unique(this); _constraints = std::make_unique(this); } @@ -175,6 +174,7 @@ void ArchEthosU55::ApplyConfig(const AcceleratorConfig *cfg) _shramMemory->SetParameters(1, 0, 0, 1, 1, 1000, 1000); _lutMemory = _shramMemory.get(); _performance = std::unique_ptr(new EthosU55Performance(this, cfg->perfInfo)); + _rcsGenerator = std::make_unique(this); } @@ -633,6 +633,7 @@ EthosU55NpuOp ArchEthosU55::GetHWOp(OpType type) {OpType::ReduceSum, EthosU55NpuOp::ReduceSum}, {OpType::Rescale, EthosU55NpuOp::Pooling}, {OpType::Tile, EthosU55NpuOp::Dma}, + {OpType::Transpose, EthosU55NpuOp::Compound}, }; auto pos = toNpuOp.find(type); if ( pos != toNpuOp.end() ) @@ -763,7 +764,7 @@ bool EthosU55OpGroup::CanRunOnNPU(const ArchitectureOpGroupQuery &op) return false; } - if ( npuOp == EthosU55NpuOp::None ) + if ( npuOp == EthosU55NpuOp::None || npuOp > EthosU55NpuOp::Compound ) { return false; } @@ -784,21 +785,6 @@ bool EthosU55OpGroup::CanRunOnNPU(const ArchitectureOpGroupQuery &op) return false; } - switch ( npuOp ) - { - case EthosU55NpuOp::Convolution: - case EthosU55NpuOp::Depthwise: - case EthosU55NpuOp::VectorProduct: - case EthosU55NpuOp::Pooling: - case EthosU55NpuOp::ReduceSum: - case EthosU55NpuOp::Elementwise: - case EthosU55NpuOp::Dma: - break; - default: - assert(false && "Unrecognized HWOp"); - return false; - } - // Validate that input/outputs shapes don't overflow if ( npuOp != EthosU55NpuOp::Dma ) { @@ -825,7 +811,8 @@ bool EthosU55OpGroup::CanRunOnNPU(const ArchitectureOpGroupQuery &op) // Check allowed ifm/ofm type mapping if ( npuOp != EthosU55NpuOp::Elementwise ) { - if ( op.type == OpType::LUT || op.type == OpType::MemoryCopy || op.type == OpType::Rescale || op.type == OpType::Tile ) + if ( op.type == OpType::LUT || op.type == OpType::MemoryCopy || op.type == OpType::Rescale || + op.type == OpType::Tile || op.type == OpType::Transpose ) { // 
TODO: LUT operations end up here due to UseAvgPoolNop although the rules are not the same as // for a Pooling operation, so skip checks for now. return true; diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55.hpp b/ethosu/regor/architecture/ethosu55/ethos_u55.hpp index f2723918b2188302731696a49de8a43b5e5fd3fc..fd39da339a23f09d59815d5851767f2fd5a2421c 100644 --- a/ethosu/regor/architecture/ethosu55/ethos_u55.hpp +++ b/ethosu/regor/architecture/ethosu55/ethos_u55.hpp @@ -65,6 +65,8 @@ enum class EthosU55NpuOp ReduceSum, Elementwise, Dma, + Compound, + Last = Compound, }; /// diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp b/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp index 7bdea0d14265d8c3e0e318ba9c42adc0f16abee0..abfac3100bb6c1084cdecc442394128f622a04be 100644 --- a/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp +++ b/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.cpp @@ -34,10 +34,10 @@ bool EthosU55Constraints::SupportsMatMul(OpType opType) return false; } -bool EthosU55Constraints::SupportsTranspose(OpType opType, TransposeType transposeType) +TransposeSupport EthosU55Constraints::SupportsTranspose(OpType opType, TransposeType transposeType) { - UNUSED(opType); - return IsNone(transposeType); + if ( IsNone(transposeType) ) return TransposeSupport::Any; + return TransposeSupport::None; } bool EthosU55Constraints::SupportsReverse(OpType opType, ReverseType reverseTypeMask) diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.hpp b/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.hpp index b12223c79c6c970611a50486487e31b6cde08aa6..e2c16f0ee17f2475d9eb1fe1d7018eba1febfe35 100644 --- a/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.hpp +++ b/ethosu/regor/architecture/ethosu55/ethos_u55_constraints.hpp @@ -29,7 +29,7 @@ public: bool SupportsLeakyRelu(bool quantized, DataType type) override; bool SupportsMatMul(OpType opType) override; - bool SupportsTranspose(OpType opType, TransposeType transposeType) override; + TransposeSupport SupportsTranspose(OpType opType, TransposeType transposeType) override; bool SupportsReverse(OpType opType, ReverseType reverseTypeMask) override; bool SupportsFusedRescale(OpType opType, TensorUsage tensorUsage, DataType fromType, DataType toType, const Quantization &quantization) override; bool SupportsRescale(DataType fromType, DataType toType) override; diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55_performance.cpp b/ethosu/regor/architecture/ethosu55/ethos_u55_performance.cpp index 8f06b1573e0853d88317d0d6994b23017f9189ba..eaebcba32c0ef490a357a10d5585052e6d6574bc 100644 --- a/ethosu/regor/architecture/ethosu55/ethos_u55_performance.cpp +++ b/ethosu/regor/architecture/ethosu55/ethos_u55_performance.cpp @@ -26,7 +26,7 @@ namespace regor { -static const Point2i s_SubkernelLimits[] = { +static const Point2i s_SubkernelLimits[size_t(EthosU55NpuOp::Last) + 1] = { {0, 0}, // No kernel {8, 8}, // Convolution {8, 8}, // Depthwise @@ -35,11 +35,12 @@ static const Point2i s_SubkernelLimits[] = { {8, 8}, // ReduceSum {1, 1}, // Elementwise {1, 1}, // Dma + {0, 0}, // Compound }; static constexpr bool OpUsesMacs(EthosU55NpuOp npuOp) { - return (npuOp != EthosU55NpuOp::Elementwise && npuOp != EthosU55NpuOp::Dma && npuOp != EthosU55NpuOp::None); + return (npuOp >= EthosU55NpuOp::Convolution) && (npuOp <= EthosU55NpuOp::ReduceSum); } EthosU55Performance::EthosU55Performance(ArchEthosU55 *arch, const EthosU55PerfInfo *perfInfo) : _arch(arch) @@ -83,6 +84,12 @@ 
CycleCost EthosU55Performance::MeasureCycleCost(const PerformanceQuery &query, c // TODO: MLBEDSW-8400 cycles.opCycles = 0; } + else if ( npuOp == EthosU55NpuOp::Compound ) + { + // TODO: Measure variable-implementation ops + assert(query.type == OpType::Transpose); + cycles.opCycles = EstimateMinimumMemoryCycles(query); + } else { assert(false && "Unknown operator cycle costing"); @@ -504,9 +511,14 @@ ElementAccess EthosU55Performance::MeasureElementAccess(const PerformanceQuery & else if ( query.type == OpType::Tile ) { // IFM0 is read multiple times to cover all elements in ofmShape - access.ifmRead[0] = Shape::RoundAway(query.ofmShape[0], ofmRounding).Elements(); + access.ifmRead[0] = Shape::RoundAway(query.ofmShape, ofmRounding).Elements(); // Complete OFM is written - access.ofmWrite = Shape::RoundAway(query.ofmShape[0], ofmRounding).Elements(); + access.ofmWrite = access.ifmRead[0]; + } + else if ( query.type == OpType::Transpose ) + { + access.ifmRead[0] = query.ifmShape[0].Elements(); + access.ofmWrite = query.ofmShape.Elements(); } else { diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.cpp b/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.cpp index dbb4b37a3cae952235410f16582bd33dcb9f3bd9..39da2358c17e1ab91935db94c28b4877b8fb18e7 100644 --- a/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.cpp +++ b/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.cpp @@ -240,6 +240,9 @@ bool EthosU55RCSGenerator::IsSupportedElementwise(const OpType opType) EthosU55RCSGenerator::EthosU55RCSGenerator(ArchEthosU55 *arch) : _arch(arch) { + int slots = (_arch->_shram.bankSizeBytes * _arch->_shram.lutBanks) / _arch->_shram.lutSlotSize; + assert(slots); + _lutSlots.resize(slots); } @@ -666,7 +669,7 @@ void EthosU55RCSGenerator::GetJobs(const Box &area, const Shape &jobShape, int n } // Calculates the value for the BLOCKDEP register -int EthosU55RCSGenerator::CalcBlockDep(HLCStripe *prevStripe, HLCStripe *stripe) +int EthosU55RCSGenerator::CalcBlockDep(const HLCStripe *prevStripe, const HLCStripe *stripe) { if ( prevStripe == nullptr ) { @@ -686,6 +689,7 @@ int EthosU55RCSGenerator::CalcBlockDep(HLCStripe *prevStripe, HLCStripe *stripe) } int ifmIndex = (op->ifm.size() > 1 && op->ifm[1].address == prevOfm.address && op->ifm[1].memArea == prevOfm.memArea) ? 
1 : 0; + assert(size_t(ifmIndex) < op->ifm.size()); const auto &ifm = op->ifm[ifmIndex]; int maxJobs = _arch->MaxBlockdep(); if ( ifm.address != prevOfm.address || ifm.memArea != prevOfm.memArea ) @@ -1162,64 +1166,39 @@ void EthosU55RCSGenerator::UpdateMemoryAccesses(const MemoryAccesses &memoryAcce } } -// Inserts DMA commands for copying LUTs from constant memory -// to LUT memory -std::vector> -EthosU55RCSGenerator::InsertLUTDMACommands(std::vector> &cmds) +// Inserts DMA commands for copying LUTs from constant memory to LUT memory +void EthosU55RCSGenerator::InsertLUTDMACommand( + int index, const HLCStripe *stripe, Temporaries &temps, std::vector &emitted) { - std::vector> result; int lutSlotSize = _arch->_shram.lutSlotSize; - int slots = (_arch->_shram.bankSizeBytes * _arch->_shram.lutBanks) / lutSlotSize; - std::vector lutSlots(slots); - int timestamp = 0; - result.reserve(cmds.size()); - for ( auto &hlc : cmds ) - { - ++timestamp; - if ( hlc->IsStripe() ) - { - auto stripe = static_cast(hlc.get()); - auto op = stripe->operation; - auto config = static_cast(op->config); - if ( op->type == OpType::LUT || (!op->subOps.empty() && op->subOps[0].type == OpType::LUT) ) - { - const auto &srcTens = op->type == OpType::LUT ? op->parameters.lut : op->subOps[0].parameters.lut; - assert(config->_layout.lutStart > 0); - assert(srcTens.sizeBytes % lutSlotSize == 0); - bool alreadyInLutMem; - int sizeInSlots = srcTens.sizeBytes / lutSlotSize; - int slot = AllocateLutSlot(lutSlots, op.get(), sizeInSlots, timestamp, alreadyInLutMem); - _stripeToLutSlot[stripe] = slot; - - if ( !alreadyInLutMem ) - { - auto dma = std::make_unique(); - dma->srcMemArea = srcTens.memArea; - dma->srcAddress = srcTens.address; - dma->length = srcTens.sizeBytes; - dma->destMemArea = _arch->LUTMemory(); - dma->destAddress = _arch->_shram.bankSizeBytes * config->_layout.lutStart + slot * lutSlotSize; - result.push_back(std::move(dma)); - } - } - else if ( _arch->_shram.reservedEndBanks == 0 ) - { - // LUT is overwritten by SHRAM accumulator buffers; clear slots - for ( auto &slot : lutSlots ) - { - slot.hlcOp = nullptr; - slot.lastUsed = 0; - } - } - } - result.push_back(std::move(hlc)); + auto op = stripe->operation; + auto config = static_cast(op->config); + + assert(op->type == OpType::LUT || (!op->subOps.empty() && op->subOps[0].type == OpType::LUT)); + + const auto &srcTens = op->type == OpType::LUT ? 
op->parameters.lut : op->subOps[0].parameters.lut; + assert(config->_layout.lutStart > 0); + assert(srcTens.sizeBytes % lutSlotSize == 0); + bool alreadyInLutMem; + int sizeInSlots = srcTens.sizeBytes / lutSlotSize; + int slot = AllocateLutSlot(_lutSlots, op.get(), sizeInSlots, index, alreadyInLutMem); + _stripeToLutSlot[stripe] = slot; + + if ( !alreadyInLutMem ) + { + auto dma = std::make_unique(); + dma->srcMemArea = srcTens.memArea; + dma->srcAddress = srcTens.address; + dma->length = srcTens.sizeBytes; + dma->destMemArea = _arch->LUTMemory(); + dma->destAddress = _arch->_shram.bankSizeBytes * config->_layout.lutStart + slot * lutSlotSize; + emitted.push_back(dma.get()); + temps.cmds.push_back(std::move(dma)); } - return result; } // Inserts DMA commands to handle TILE operations -std::vector> -EthosU55RCSGenerator::InsertTileDMACommands(std::vector> &cmds) +void EthosU55RCSGenerator::InsertTileDMACommand(const HLCStripe *stripe, Temporaries &temps, std::vector &emitted) { // reshape to 3D-tensor where the width-axis is being tiled static auto reshapeFunc = [](Shape &shape, int tiledAxis) @@ -1240,55 +1219,46 @@ EthosU55RCSGenerator::InsertTileDMACommands(std::vector> result; - for ( auto &hlc : cmds ) - { - if ( hlc->IsStripe() ) - { - auto stripe = static_cast(hlc.get()); - auto op = stripe->operation; - if ( op->type == OpType::Tile ) - { - auto &ifm = op->ifm[0]; - auto &ofm = op->ofm; + auto op = stripe->operation; + assert(op->type == OpType::Tile); - assert(ifm.format == TensorFormat::NHWC); - assert(ofm.format == TensorFormat::NHWC); + auto &ifm = op->ifm[0]; + auto &ofm = op->ofm; - const auto &tileParams = op->parameters.tile; + assert(ifm.format == TensorFormat::NHWC); + assert(ofm.format == TensorFormat::NHWC); - reshapeFunc(ifm.shape, tileParams.axis); - reshapeFunc(ofm.shape, tileParams.axis); + const auto &tileParams = op->parameters.tile; - int srcOffset = 0; - int dstOffset = 0; - int elemSize = DataTypeSizeBits(ifm.dataType) / 8; - int rowBytes = ifm.shape[2] * ifm.shape[3] * elemSize; - // each row in the IFM is copied separately - // and duplicated based on the multiplier attribute. - for ( int h = 0; h < ifm.shape.Height(); h++ ) - { - for ( int i = 0; i < tileParams.multiplier; i++ ) - { - auto dma = std::make_unique(); - dma->srcMemArea = ifm.memArea; - dma->srcAddress = ifm.address + srcOffset; - dma->length = rowBytes; - dma->destMemArea = ofm.memArea; - dma->destAddress = ofm.address + dstOffset; - result.push_back(std::move(dma)); - dstOffset += rowBytes; - } - srcOffset += rowBytes; - } - continue; - } + reshapeFunc(ifm.shape, tileParams.axis); + reshapeFunc(ofm.shape, tileParams.axis); + + int srcOffset = 0; + int dstOffset = 0; + int elemSize = DataTypeSizeBits(ifm.dataType) / 8; + int rowBytes = ifm.shape[2] * ifm.shape[3] * elemSize; + // each row in the IFM is copied separately + // and duplicated based on the multiplier attribute. 
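+ // i.e. the row at srcOffset is copied 'multiplier' times to consecutive dstOffset positions before srcOffset advances by rowBytes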
+ for ( int h = 0; h < ifm.shape.Height(); h++ ) + { + for ( int i = 0; i < tileParams.multiplier; i++ ) + { + auto dma = std::make_unique(); + dma->srcMemArea = ifm.memArea; + dma->srcAddress = ifm.address + srcOffset; + dma->length = rowBytes; + dma->destMemArea = ofm.memArea; + dma->destAddress = ofm.address + dstOffset; + emitted.push_back(dma.get()); + temps.cmds.push_back(std::move(dma)); + dstOffset += rowBytes; } - result.push_back(std::move(hlc)); + srcOffset += rowBytes; } - return result; } + + //---------------------------------------------------------------------- // Operations //---------------------------------------------------------------------- @@ -1385,7 +1355,7 @@ void EthosU55RCSGenerator::GenerateConvolutionOp(const HLCStripe *stripe, Memory } // MaxPool/AvgPool/ResizeBilinear or operations that are mapped to AvgPool -void EthosU55RCSGenerator::GeneratePoolingOp(HLCStripe *stripe, MemoryAccesses &memoryAccesses) +void EthosU55RCSGenerator::GeneratePoolingOp(const HLCStripe *stripe, MemoryAccesses &memoryAccesses) { auto op = stripe->operation.get(); auto pad = stripe->padding; @@ -1404,7 +1374,7 @@ void EthosU55RCSGenerator::GeneratePoolingOp(HLCStripe *stripe, MemoryAccesses & } // Elementwise operations -void EthosU55RCSGenerator::GenerateElementwiseOp(HLCStripe *stripe, MemoryAccesses &memoryAccesses) +void EthosU55RCSGenerator::GenerateElementwiseOp(const HLCStripe *stripe, MemoryAccesses &memoryAccesses) { auto op = stripe->operation.get(); auto opType = op->type; @@ -1436,19 +1406,23 @@ void EthosU55RCSGenerator::GenerateElementwiseOp(HLCStripe *stripe, MemoryAccess auto opToScale = GenerateScalingForElementwise(op, ifmIndex); GenerateCommon(stripe, useGlobalScale, opToScale, memoryAccesses, ifmIndex); int ifm2Index = 1 - ifmIndex; - bool isScalar = IsScalar(op->ifm[ifm2Index], scalarValue); - GenerateIFM2(opType, op->ifm[ifm2Index], stripe->ifmAreas[ifm2Index], isScalar, scalarValue); + assert(size_t(ifm2Index) < stripe->ifmAreas.size()); + const HLCFeatureMap &ifm2 = op->ifm.at(ifm2Index); + bool isScalar = IsScalar(ifm2, scalarValue); + GenerateIFM2(opType, ifm2, stripe->ifmAreas[ifm2Index], isScalar, scalarValue); if ( !isScalar ) { - memoryAccesses.push_back(ToMemoryAccess(op->ifm[ifm2Index], stripe->ifmAreas[ifm2Index], AccessDirection::Read)); + memoryAccesses.push_back(ToMemoryAccess(ifm2, stripe->ifmAreas[ifm2Index], AccessDirection::Read)); } - GenerateIFM2Precision(op->ifm[ifm2Index]); + GenerateIFM2Precision(ifm2); GenerateIFM2Broadcast(ifmShape, ifm2Shape, reversedOperands, isScalar); } } -bool EthosU55RCSGenerator::GenerateStripe(HLCStripe *stripe, MemoryAccesses &memoryAccesses) +bool EthosU55RCSGenerator::GenerateStripe(const HLCStripe *stripe, const HLCStripe *prevStripe, AccessTracking &accesses) { + MemoryAccesses memoryAccesses; + auto opType = stripe->operation->type; EthosU55NpuOp npuOp = ArchEthosU55::GetHWOp(opType); if ( npuOp == EthosU55NpuOp::Pooling || npuOp == EthosU55NpuOp::ReduceSum ) @@ -1472,12 +1446,21 @@ bool EthosU55RCSGenerator::GenerateStripe(HLCStripe *stripe, MemoryAccesses &mem EthosU55OpConfig *config = static_cast(stripe->operation->config); GenerateBlockConfig(config); GenerateShramRegisters(config, stripe->operation->ifm.size() >= 2); + + // BLOCKDEP register tracking + int blockdep = CalcBlockDep(prevStripe, stripe); + Emit(isa::npu_set_blockdep_t(blockdep)); + GenerateWaits(false, memoryAccesses, accesses.outstandingDmaAccesses); + UpdateMemoryAccesses(memoryAccesses, accesses.outstandingNpuAccesses, 
accesses.maxOutstandingKernelOps); + GenerateOperationCode(stripe->operation->type); return true; } // Generates register commands for DMA operations -void EthosU55RCSGenerator::GenerateDMA(const HLCDMA *dma, MemoryAccesses &memoryAccesses) +void EthosU55RCSGenerator::GenerateDMA(const HLCDMA *dma, AccessTracking &accesses) { + MemoryAccesses memoryAccesses; + auto srcRegionMode = dma_region_mode::EXTERNAL; auto destRegionMode = dma_region_mode::EXTERNAL; if ( dma->destMemArea == _arch->LUTMemory() ) @@ -1492,67 +1475,113 @@ void EthosU55RCSGenerator::GenerateDMA(const HLCDMA *dma, MemoryAccesses &memory Emit(isa::npu_set_dma0_dst_region_t(ToRegion(dma->destMemArea), destRegionMode, strideMode)); Emit(isa::npu_set_dma0_dst_t(dma->destAddress)); Emit(isa::npu_set_dma0_len_t(dma->length)); + + // Track memory accesses memoryAccesses.emplace_back(AccessDirection::Read, dma->srcMemArea, dma->srcAddress, dma->srcAddress + dma->length); memoryAccesses.emplace_back(AccessDirection::Write, dma->destMemArea, dma->destAddress, dma->destAddress + dma->length); + GenerateWaits(false, memoryAccesses, accesses.outstandingDmaAccesses); + GenerateWaits(true, memoryAccesses, accesses.outstandingNpuAccesses); + UpdateMemoryAccesses(memoryAccesses, accesses.outstandingDmaAccesses, accesses.maxOutstandingDMAOps); + + Emit(isa::npu_op_dma_start_t()); } +void EthosU55RCSGenerator::PrepareCommand(int index, HighLevelCommand *cmd, Temporaries &temps, std::vector &emitted) +{ + emitted.clear(); + + if ( cmd->IsStripe() ) + { + HLCStripe *stripe = static_cast(cmd); + auto op = stripe->operation; + if ( op->type == OpType::Tile ) + { + InsertTileDMACommand(stripe, temps, emitted); + return; // Return early to replace original op + } + else if ( op->type == OpType::LUT || (!op->subOps.empty() && op->subOps[0].type == OpType::LUT) ) + { + InsertLUTDMACommand(index, stripe, temps, emitted); + } + else if ( _arch->_shram.reservedEndBanks == 0 ) + { + // LUT is overwritten by SHRAM accumulator buffers; clear slots + for ( auto &slot : _lutSlots ) + { + slot.hlcOp = nullptr; + slot.lastUsed = 0; + } + } + } + + // Emit original op + emitted.push_back(cmd); +} + + std::vector EthosU55RCSGenerator::GenerateCommandStream(std::vector> &highLevelCommandStream, std::vector> *cmdRanges, bool verbose) { _emit.Clear(); _stripeToLutSlot.clear(); + // Clear lut slots at start of command stream generation + for ( auto &slot : _lutSlots ) + { + slot.hlcOp = nullptr; + slot.lastUsed = 0; + } + GenerateInitialRegisterSetup(); - auto cmds = InsertLUTDMACommands(highLevelCommandStream); - cmds = InsertTileDMACommands(cmds); - std::deque outstandingDmaAccesses; - std::deque outstandingNpuAccesses; - int maxOutstandingDMAOps = _arch->MaxOutstandingDMAOps(); - int maxOutstandingKernelOps = _arch->MaxOutstandingKernelOps(); - HLCStripe *prevOp = nullptr; + + AccessTracking accesses; + accesses.maxOutstandingDMAOps = _arch->MaxOutstandingDMAOps(); + accesses.maxOutstandingKernelOps = _arch->MaxOutstandingKernelOps(); + + const HLCStripe *prevStripe = nullptr; std::vector> debugInfo; - for ( auto &hlc : cmds ) + + Temporaries temporaries; + std::vector emitted(4); + + int cmdIndex = 0; + for ( const auto &cmd : highLevelCommandStream ) { int emitStart = _emit.Position(); - if ( hlc->IsStripe() ) + + PrepareCommand(cmdIndex, cmd.get(), temporaries, emitted); + + for ( auto hlc : emitted ) { - MemoryAccesses memoryAccesses; - auto stripe = static_cast(hlc.get()); - if ( verbose ) + if ( hlc->IsStripe() ) { - debugInfo.emplace_back(emitStart, 
stripe->operation->ToString()); - } - if ( !GenerateStripe(stripe, memoryAccesses) ) - { - return std::vector(); + auto stripe = static_cast(hlc); + if ( verbose ) + { + debugInfo.emplace_back(_emit.Position(), stripe->operation->ToString()); + } + if ( !GenerateStripe(stripe, prevStripe, accesses) ) + { + return std::vector(); + } + prevStripe = stripe; } - // BLOCKDEP register - int blockdep = CalcBlockDep(prevOp, stripe); - Emit(isa::npu_set_blockdep_t(blockdep)); - GenerateWaits(false, memoryAccesses, outstandingDmaAccesses); - UpdateMemoryAccesses(memoryAccesses, outstandingNpuAccesses, maxOutstandingKernelOps); - GenerateOperationCode(stripe->operation->type); - prevOp = stripe; - // Return command mapping information to the caller - int emitEnd = _emit.Position(); - if ( cmdRanges ) + else { - cmdRanges->emplace_back(stripe->operation->_srcKey, emitStart, emitEnd); + auto dma = static_cast(hlc); + if ( verbose ) + { + debugInfo.emplace_back(_emit.Position(), dma->ToString()); + } + GenerateDMA(dma, accesses); } } - else + + // Return command mapping information to the caller + if ( cmdRanges && cmd->IsStripe() ) { - MemoryAccesses dmaAccesses; - auto dma = static_cast(hlc.get()); - if ( verbose ) - { - debugInfo.emplace_back(emitStart, dma->ToString()); - } - GenerateDMA(static_cast(hlc.get()), dmaAccesses); - GenerateWaits(false, dmaAccesses, outstandingDmaAccesses); - GenerateWaits(true, dmaAccesses, outstandingNpuAccesses); - UpdateMemoryAccesses(dmaAccesses, outstandingDmaAccesses, maxOutstandingDMAOps); - Emit(isa::npu_op_dma_start_t()); + cmdRanges->emplace_back(static_cast(cmd.get())->operation->_srcKey, emitStart, _emit.Position()); } + cmdIndex++; } Emit(isa::npu_op_stop_t(0xFFFF)); if ( verbose ) diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.hpp b/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.hpp index ec3c6d4c3415ef23bb923ef585feff2b0db9f6a6..f4edae0430f2461a6d61fd12595df98a1694e50d 100644 --- a/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.hpp +++ b/ethosu/regor/architecture/ethosu55/ethos_u55_register_cs_generator.hpp @@ -125,6 +125,13 @@ struct LutSlot /// class EthosU55RCSGenerator : public EthosURegisterCSGenerator { +private: + ArchEthosU55 *_arch; + // For stripes that use LUT: the LUT slot to be used + std::unordered_map _stripeToLutSlot; + std::vector _lutSlots; + EthosU55Emitter _emit; + public: EthosU55RCSGenerator(ArchEthosU55 *arch); @@ -186,7 +193,7 @@ protected: // (in that case, the very last job is added last) void GetJobs(const Box &area, const Shape &block, int nrJobsToGet, bool fromStart, std::vector &jobs); // Calculates the value for the BLOCKDEP register - int CalcBlockDep(HLCStripe *prevStripe, HLCStripe *stripe); + int CalcBlockDep(const HLCStripe *prevStripe, const HLCStripe *stripe); @@ -225,16 +232,30 @@ protected: void GenerateWaits(bool isKernelWait, const MemoryAccesses &memoryAccesses, std::deque &outstandingAccesses); // Save current memory accesses to accessesToUpdate void UpdateMemoryAccesses(const MemoryAccesses &memoryAccesses, std::deque &accessesToUpdate, int maxWaits); - // Inserts DMA commands for copying LUTs from constant memory - // to LUT memory - std::vector> InsertLUTDMACommands(std::vector> &cmds); + + struct Temporaries + { + std::vector> cmds; + std::vector> configs; + }; + + // Inserts DMA commands for copying LUTs from constant memory to LUT memory + void InsertLUTDMACommand(int index, const HLCStripe *stripe, Temporaries &temps, std::vector 
&emitted); // Inserts DMA commands to handle TILE operations - virtual std::vector> InsertTileDMACommands(std::vector> &cmds); + virtual void InsertTileDMACommand(const HLCStripe *stripe, Temporaries &temps, std::vector &emitted); //---------------------------------------------------------------------- // Operations //---------------------------------------------------------------------- + struct AccessTracking + { + std::deque outstandingNpuAccesses; + std::deque outstandingDmaAccesses; + int maxOutstandingDMAOps; + int maxOutstandingKernelOps; + }; + // Generates NPU_OP_* command void GenerateOperationCode(OpType opType); void GenerateCommon(const HLCStripe *stripe, bool useGlobalScale, RCSIfmScaleMode opToScale, @@ -242,12 +263,13 @@ protected: // Conv2D/Depthwise operations void GenerateConvolutionOp(const HLCStripe *stripe, MemoryAccesses &memoryAccesses); // MaxPool/AvgPool/ResizeBilinear or operations that are mapped to AvgPool - void GeneratePoolingOp(HLCStripe *stripe, MemoryAccesses &memoryAccesses); + void GeneratePoolingOp(const HLCStripe *stripe, MemoryAccesses &memoryAccesses); // Elementwise operations - void GenerateElementwiseOp(HLCStripe *stripe, MemoryAccesses &memoryAccesses); - bool GenerateStripe(HLCStripe *stripe, MemoryAccesses &memoryAccesses); + void GenerateElementwiseOp(const HLCStripe *stripe, MemoryAccesses &memoryAccesses); + bool GenerateStripe(const HLCStripe *stripe, const HLCStripe *prevStripe, AccessTracking &accesses); + void PrepareCommand(int index, HighLevelCommand *cmd, Temporaries &temps, std::vector &emitted); // Generates register commands for DMA operations - virtual void GenerateDMA(const HLCDMA *dma, MemoryAccesses &memoryAccesses); + virtual void GenerateDMA(const HLCDMA *dma, AccessTracking &accesses); virtual void GenerateInitialRegisterSetup() { @@ -260,12 +282,6 @@ public: static uint32_t IdRegister(); static bool IsSupportedElementwise(const OpType opType); - -private: - ArchEthosU55 *_arch; - // For stripes that use LUT: the LUT slot to be used - std::unordered_map _stripeToLutSlot; - EthosU55Emitter _emit; }; } // namespace regor diff --git a/ethosu/regor/architecture/ethosu65/ethos_u65_register_cs_generator.cpp b/ethosu/regor/architecture/ethosu65/ethos_u65_register_cs_generator.cpp index edda5ad2dc9fe34ef99ea1ab0b1890c2ddc9a099..337568cfb4aa78de1c58f7ab96ebac92f6d82640 100644 --- a/ethosu/regor/architecture/ethosu65/ethos_u65_register_cs_generator.cpp +++ b/ethosu/regor/architecture/ethosu65/ethos_u65_register_cs_generator.cpp @@ -31,9 +31,8 @@ EthosU65RCSGenerator::EthosU65RCSGenerator(ArchEthosU65 *arch) : EthosU55RCSGene { } -// Converts TILE operations into 3D (or 2D) DMA operations -std::vector> -EthosU65RCSGenerator::InsertTileDMACommands(std::vector> &cmds) + +void EthosU65RCSGenerator::InsertTileDMACommand(const HLCStripe *stripe, Temporaries &temps, std::vector &emitted) { // reshape to 3D-tensor where the width-axis is being tiled static auto reshapeFunc = [](Shape &shape, int tiledAxis) @@ -54,70 +53,62 @@ EthosU65RCSGenerator::InsertTileDMACommands(std::vector> result; - for ( auto &hlc : cmds ) + auto op = stripe->operation; + assert(op->type == OpType::Tile); + + // convert tile-operation to multiple DMA operations + auto &ifm = op->ifm[0]; + auto &ofm = op->ofm; + // max-height for 2D/3D DMA operations + constexpr int maxHeight = (1 << 16) - 1; + + assert(ifm.format == TensorFormat::NHWC); + assert(ofm.format == TensorFormat::NHWC); + + const auto &tileParams = op->parameters.tile; + + reshapeFunc(ifm.shape, 
tileParams.axis); + reshapeFunc(ofm.shape, tileParams.axis); + + int elemSize = DataTypeSizeBits(ifm.dataType) / 8; + auto srcStrides = Shape::GetStridesForShape(ifm.shape, {1, 1, 1, elemSize}); + auto dstStrides = Shape::GetStridesForShape(ofm.shape, {1, 1, 1, elemSize}); + + int srcheightOffset = 0; + int dstheightOffset = 0; + int height = ifm.shape.Height(); + while ( height > 0 ) { - if ( hlc->IsStripe() ) + int heightSlice = std::min(height, maxHeight); + + // create 2D/3D DMA that copies ifm to ofm + for ( int i = 0; i < tileParams.multiplier; i++ ) { - auto stripe = static_cast(hlc.get()); - auto op = stripe->operation; - if ( op->type == OpType::Tile ) - { - // convert tile-operation to multiple DMA operations - auto &ifm = op->ifm[0]; - auto &ofm = op->ofm; - // max-height for 2D/3D DMA operations - constexpr int maxHeight = (1 << 16) - 1; - - assert(ifm.format == TensorFormat::NHWC); - assert(ofm.format == TensorFormat::NHWC); - - const auto &tileParams = op->parameters.tile; - - reshapeFunc(ifm.shape, tileParams.axis); - reshapeFunc(ofm.shape, tileParams.axis); - - int elemSize = DataTypeSizeBits(ifm.dataType) / 8; - auto srcStrides = Shape::GetStridesForShape(ifm.shape, {1, 1, 1, elemSize}); - auto dstStrides = Shape::GetStridesForShape(ofm.shape, {1, 1, 1, elemSize}); - - int srcheightOffset = 0; - int dstheightOffset = 0; - int height = ifm.shape.Height(); - while ( height > 0 ) - { - int heightSlice = std::min(height, maxHeight); - - // create 2D/3D DMA that copies ifm to ofm - for ( int i = 0; i < tileParams.multiplier; i++ ) - { - int addrOffset = i * ifm.shape.Width() * srcStrides.Width(); - auto dma = std::make_unique(); - dma->srcMemArea = ifm.memArea; - dma->srcAddress = ifm.address + srcheightOffset; - dma->srcStrides = srcStrides; - dma->length = ifm.shape.Depth() * elemSize; - dma->sizes = Shape(heightSlice, ifm.shape.Width()); - dma->destMemArea = ofm.memArea; - dma->destAddress = ofm.address + dstheightOffset + addrOffset; - dma->destStrides = dstStrides; - result.push_back(std::move(dma)); - } - height -= heightSlice; - srcheightOffset += heightSlice * srcStrides.Height(); - dstheightOffset += heightSlice * dstStrides.Height(); - } - continue; - } + int addrOffset = i * ifm.shape.Width() * srcStrides.Width(); + auto dma = std::make_unique(); + dma->srcMemArea = ifm.memArea; + dma->srcAddress = ifm.address + srcheightOffset; + dma->srcStrides = srcStrides; + dma->length = ifm.shape.Depth() * elemSize; + dma->sizes = Shape(heightSlice, ifm.shape.Width()); + dma->destMemArea = ofm.memArea; + dma->destAddress = ofm.address + dstheightOffset + addrOffset; + dma->destStrides = dstStrides; + emitted.push_back(dma.get()); + temps.cmds.push_back(std::move(dma)); } - result.push_back(std::move(hlc)); + height -= heightSlice; + srcheightOffset += heightSlice * srcStrides.Height(); + dstheightOffset += heightSlice * dstStrides.Height(); } - return result; } + // Generates register commands for DMA operations -void EthosU65RCSGenerator::GenerateDMA(const HLCDMA *dma, MemoryAccesses &memoryAccesses) +void EthosU65RCSGenerator::GenerateDMA(const HLCDMA *dma, AccessTracking &accesses) { + MemoryAccesses memoryAccesses; + auto srcRegionMode = dma_region_mode::EXTERNAL; auto destRegionMode = dma_region_mode::EXTERNAL; @@ -181,6 +172,13 @@ void EthosU65RCSGenerator::GenerateDMA(const HLCDMA *dma, MemoryAccesses &memory memoryAccesses.emplace_back(AccessDirection::Read, dma->srcMemArea, dma->srcAddress, dma->srcAddress + dma->srcStrides[0]); 
memoryAccesses.emplace_back(AccessDirection::Write, dma->destMemArea, dma->destAddress, dma->destAddress + dma->destStrides[0]); } + + // Track memory accesses + GenerateWaits(false, memoryAccesses, accesses.outstandingDmaAccesses); + GenerateWaits(true, memoryAccesses, accesses.outstandingNpuAccesses); + UpdateMemoryAccesses(memoryAccesses, accesses.outstandingDmaAccesses, accesses.maxOutstandingDMAOps); + + Emit(isa::npu_op_dma_start_t()); } void EthosU65RCSGenerator::GenerateInitialRegisterSetup() diff --git a/ethosu/regor/architecture/ethosu65/ethos_u65_register_cs_generator.hpp b/ethosu/regor/architecture/ethosu65/ethos_u65_register_cs_generator.hpp index e9d898564376a7d4ce44a4d5c48ad26319121bd9..3a8e33675c24b2deb7758d7b7e2f07d93f97cd8b 100644 --- a/ethosu/regor/architecture/ethosu65/ethos_u65_register_cs_generator.hpp +++ b/ethosu/regor/architecture/ethosu65/ethos_u65_register_cs_generator.hpp @@ -34,10 +34,10 @@ public: protected: // Converts TILE operations to DMA commands - std::vector> InsertTileDMACommands(std::vector> &cmds) override; + void InsertTileDMACommand(const HLCStripe *stripe, Temporaries &temps, std::vector &emitted) override; // Generate register commands for DMA operations - void GenerateDMA(const HLCDMA *dma, MemoryAccesses &memoryAccesses) override; + void GenerateDMA(const HLCDMA *dma, AccessTracking &accesses) override; void GenerateInitialRegisterSetup() override; private: diff --git a/ethosu/regor/architecture/ethosu85/ethos_u85.cpp b/ethosu/regor/architecture/ethosu85/ethos_u85.cpp index 0bc746bea5a0dfbd0230097f4a4fbdd02daac0a4..bcb1b7c35280b36de1202bdd2d9a30c4d57a7127 100644 --- a/ethosu/regor/architecture/ethosu85/ethos_u85.cpp +++ b/ethosu/regor/architecture/ethosu85/ethos_u85.cpp @@ -1449,7 +1449,7 @@ bool EthosU85OpGroup::Fuse(const ArchitectureOpGroupQuery &op, const std::vector } // Can't fuse a transpose type that's not supported by primaryOp in opgroup - if ( !_arch->_constraints->SupportsTranspose(_ops[0].type, op.ofm.transpose) ) + if ( _arch->_constraints->SupportsTranspose(_ops[0].type, op.ofm.transpose) == TransposeSupport::None ) { return false; } @@ -1774,7 +1774,7 @@ bool EthosU85OpGroup::CanRunOnNPU(const ArchitectureOpGroupQuery &op) if ( op.type == OpType::Transpose ) { - return _arch->_constraints->SupportsTranspose(OpType::MemoryCopy, op.ofm.transpose); + return _arch->_constraints->SupportsTranspose(OpType::MemoryCopy, op.ofm.transpose) != TransposeSupport::None; } if ( op.type == OpType::Reverse ) diff --git a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp index 339b3cd40dd9fea83f12acfe16cbbea0413b2322..888942d750b9e26f8276cac7bccbce12d7090c25 100644 --- a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp +++ b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.cpp @@ -39,22 +39,30 @@ bool EthosU85Constraints::SupportsMatMul(OpType opType) return true; } -bool EthosU85Constraints::SupportsTranspose(OpType opType, TransposeType transposeType) +TransposeSupport EthosU85Constraints::SupportsTranspose(OpType opType, TransposeType transposeType) { - if ( IsNone(transposeType) ) return true; + if ( IsNone(transposeType) ) return TransposeSupport::Any; EthosU85NpuOp npuOp = ArchEthosU85::GetHWOp(opType); if ( npuOp == EthosU85NpuOp::None || npuOp == EthosU85NpuOp::Resize || npuOp == EthosU85NpuOp::Dma ) { - return false; + return TransposeSupport::None; } else if ( npuOp == EthosU85NpuOp::Elementwise ) { - return transposeType == TransposeType::None 
|| transposeType == TransposeType::NHCW || transposeType == TransposeType::NCHW; + if ( transposeType == TransposeType::None || transposeType == TransposeType::NHCW || transposeType == TransposeType::NCHW ) + { + return TransposeSupport::Any; + } + + return TransposeSupport::None; } - return transposeType == TransposeType::None || transposeType == TransposeType::NWHC || transposeType == TransposeType::NHCW || - transposeType == TransposeType::NWCH || transposeType == TransposeType::NCHW || transposeType == TransposeType::NCWH; + if ( transposeType == TransposeType::None || transposeType == TransposeType::NWHC || transposeType == TransposeType::NHCW || + transposeType == TransposeType::NWCH || transposeType == TransposeType::NCHW || transposeType == TransposeType::NCWH ) + return TransposeSupport::Any; + + return TransposeSupport::None; } bool EthosU85Constraints::SupportsReverse(OpType opType, ReverseType reverseTypeMask) diff --git a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.hpp b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.hpp index b06795e5744bdd27bdd5bf0f5fc2e5ca3afa3617..228c08258d9acee402aca1e1bb66c313c9510bd5 100644 --- a/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.hpp +++ b/ethosu/regor/architecture/ethosu85/ethos_u85_constraints.hpp @@ -29,7 +29,7 @@ public: bool SupportsLeakyRelu(bool quantized, DataType type) override; bool SupportsMatMul(OpType opType) override; - bool SupportsTranspose(OpType opType, TransposeType transposeType) override; + TransposeSupport SupportsTranspose(OpType opType, TransposeType transposeType) override; bool SupportsReverse(OpType opType, ReverseType reverseTypeMask) override; bool SupportsFusedRescale(OpType opType, TensorUsage tensorUsage, DataType fromType, DataType toType, const Quantization &quantization) override; bool SupportsRescale(DataType fromType, DataType toType) override; diff --git a/ethosu/regor/compiler/compiler.cpp b/ethosu/regor/compiler/compiler.cpp index b7fd2edc5b078572c48d39b043a427d7b49e8ea9..e81547df4ae96aeac3eacfcb4fafd9b5c1907576 100644 --- a/ethosu/regor/compiler/compiler.cpp +++ b/ethosu/regor/compiler/compiler.cpp @@ -217,13 +217,12 @@ public: }; -bool Compiler::Store(const std::vector> &graphs, +void Compiler::Store(const std::vector> &graphs, const std::vector> &tensorAddressMaps) { if ( _compilerOptions.outputFormat == OutputFormat::Raw ) { RawWriter writer; - // This will serialise multiple blobs auto buffers = writer.Serialise(graphs, tensorAddressMaps); @@ -245,8 +244,6 @@ bool Compiler::Store(const std::vector> &graphs, RawBlob *output = new RawBlob(std::move(buffer), offset, int64_t(size)); _output.push_back(output); } - - return true; } @@ -301,9 +298,18 @@ bool Compiler::Compile() } _optDb.reset(); - Store(newGraphs, tensorAddressMaps); - _builders.clear(); + + try + { + Store(newGraphs, tensorAddressMaps); + } + catch ( const std::invalid_argument &ex ) + { + SetLastError(fmt::format("Output error: {} \n", ex.what())); + return false; + } + return true; } @@ -377,23 +383,32 @@ std::unique_ptr Compiler::CompileGraph(std::unique_ptr &graph, return nullptr; } } - if ( graph->Notation() == GraphNotation::TFLite ) + + try { - // Run GraphNotation::TFLite Preprocess/optimise step + if ( graph->Notation() == GraphNotation::TFLite ) + { + // Run GraphNotation::TFLite Preprocess/optimise step + std::unique_ptr optimiser = GraphOptimiser::MakeGraphOptimiser( + GraphNotation::TFLite, _architecture->Constraints(), _graphOptimiserOptions, _optDb.get()); + if ( optimiser ) + { + 
optimiser->Process(graph.get()); + } + } + + // Run GraphNotation::GraphAPI Preprocess/optimise step std::unique_ptr optimiser = GraphOptimiser::MakeGraphOptimiser( - GraphNotation::TFLite, _architecture->Constraints(), _graphOptimiserOptions, _optDb.get()); + GraphNotation::GraphAPI, _architecture->Constraints(), _graphOptimiserOptions, _optDb.get()); if ( optimiser ) { optimiser->Process(graph.get()); } } - - // Run GraphNotation::GraphAPI Preprocess/optimise step - std::unique_ptr optimiser = GraphOptimiser::MakeGraphOptimiser( - GraphNotation::GraphAPI, _architecture->Constraints(), _graphOptimiserOptions, _optDb.get()); - if ( optimiser ) + catch ( const std::runtime_error &e ) { - optimiser->Process(graph.get()); + SetLastError(e.what()); + return nullptr; } // Pack/linearise graph Operations into SchedulerOperations diff --git a/ethosu/regor/compiler/compiler.hpp b/ethosu/regor/compiler/compiler.hpp index f14fde1852dc6700cae79b24bb14d6b54654a8ba..260cd3c637e2a1e7cf2328aa7bb5633cb57549da 100644 --- a/ethosu/regor/compiler/compiler.hpp +++ b/ethosu/regor/compiler/compiler.hpp @@ -94,7 +94,7 @@ public: bool LoadTosa(const void *input, size_t size); bool LoadTflite(const void *input, size_t size); - bool Store(const std::vector> &graphs, + void Store(const std::vector> &graphs, const std::vector> &tensorAddressMaps); bool Compile(); diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp index 78de0637b869f60098c10e6be89d45a37047a94a..18b6c873ed5a8f67c8c78bab3644d38b60cdc669 100644 --- a/ethosu/regor/compiler/graphir_optimiser.cpp +++ b/ethosu/regor/compiler/graphir_optimiser.cpp @@ -173,12 +173,6 @@ Operation *GraphIrOptimiser::ConvertAttributes(Graph *const graph, Operation *co ofmConn->quantization.scales[0].shift += attr->shift; attr->shift = 0; } - else if ( opType == OpType::Transpose ) - { - const auto *attr = operation->Attribute(); - TensorConnection *ofmConn = operation->Output(TensorUsage::OFM); - ofmConn->transpose = TransposeTypeFromShape(attr->perm); - } else if ( opType == OpType::Reverse ) { // Convert TOSA axis attribute to ReverseType representation @@ -1438,6 +1432,93 @@ Operation *GraphIrOptimiser::RewriteTile(Graph *const, Operation *const operatio return returnOp; } +// Merge adjacent transposes +Operation *GraphIrOptimiser::MergeTransposes(Graph *const graph, Operation *const operation) +{ + UNUSED(graph); + Operation *returnOp = operation; + const OpType opType = operation->Type(); + if ( opType == OpType::Transpose ) + { + auto *ifmConn = operation->Input(TensorUsage::IFM); + auto *ofmConn = operation->Output(TensorUsage::OFM); + auto *ifm = ifmConn->tensor.get(); + const auto &ofm = ofmConn->tensor; + auto *prevOp = ifm->Writers().empty() ? nullptr : ifm->Writers().front().get(); + + auto *attr = operation->Attribute(); + auto curTranspose = TransposeTypeFromShape(attr->perm); + bool opHasQuant = ofmConn->quantization.IsValid() && !ofmConn->quantization.IsUnitScale(); + + // Remove no-op transposes if possible + if ( IsNone(curTranspose) ) + { + assert(ofmConn->shape == ifmConn->shape); + // Transpose is the only operator, it may be peforming memory copy duties. 
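+ // No IFM producer and no OFM readers: replace the Transpose with a MemoryCopy so the data movement is kept.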
+ if ( !prevOp && ofm->Readers().empty() ) + { + auto newOp = std::make_shared(OpType::MemoryCopy); + newOp->CopyInput(TensorUsage::IFM0, *ifmConn); + newOp->CopyOutput(TensorUsage::OFM, *ofmConn); + operation->Disconnect(); + returnOp = newOp.get(); + RecordOptimisation(operation, returnOp); + } + // Disconnect from surrounding ops, if this is a graph input + // or output it remains untouched. + else if ( ifm->IsSinglePath() && !opHasQuant && prevOp ) + { + ifm->RemoveWriter(prevOp->shared_from_this()); + prevOp->ConnectOutput(TensorUsage::OFM, ofm).Set(ofmConn->slice); + operation->Disconnect(); + returnOp = prevOp; + } + return returnOp; + } + + // Transpose is fed by a preceding transpose (single writer, single reader) + if ( prevOp && (prevOp->Type() == OpType::Transpose) && ifm->IsSinglePath() ) + { + const auto *prevConn = prevOp->Output(TensorUsage::OFM); + assert(prevConn); + + // Can't merge if predecessor reverses or reshapes + if ( prevConn->reverse != ReverseType::None || prevConn->shape != ifmConn->shape ) return returnOp; + + // Can't merge if both apply quantization + bool prevHasQuant = prevConn->quantization.IsValid() && !prevConn->quantization.IsUnitScale(); + if ( opHasQuant && prevHasQuant ) return returnOp; + + // Examine previous op's transpose + auto *prevAttr = prevOp->Attribute(); + auto prevTranspose = TransposeTypeFromShape(prevAttr->perm); + + // Apply both transposes to default axes and examine the resulting transpose + static std::array nhwcDefault = {0, 1, 2, 3, 4, 5, 6, 7}; + int activeAxes = std::min(int(nhwcDefault.size()), ifmConn->shape.Size()); + + Shape axes(nhwcDefault.data(), activeAxes); + Shape prevMapping = axes.Permute(unsigned(prevTranspose)); + Shape finalMapping = prevMapping.Permute(unsigned(curTranspose)); + TransposeType mergedTranspose = TransposeTypeFromShape(finalMapping); + + // The single merged transpose is supported + if ( _constraints->SupportsTranspose(OpType::Transpose, mergedTranspose) != TransposeSupport::None ) + { + // Change the transpose attribute on the preceding transpose and remove this one + prevAttr->perm = finalMapping; + TensorConnection &newConn = prevOp->ConnectOutput(TensorUsage::OFM, ofm); + newConn.Set(ofmConn->slice).Set(ofmConn->reverse).Set(ofmConn->shape); + if ( !prevHasQuant && opHasQuant ) newConn.Set(ofmConn->quantization); + operation->Disconnect(); + return prevOp; + } + } + } + + return returnOp; +} + // Rearrange transpose Operation *GraphIrOptimiser::RearrangeTranspose(Graph *const graph, Operation *const operation) { @@ -1465,17 +1546,17 @@ Operation *GraphIrOptimiser::RearrangeTranspose(Graph *const graph, Operation *c // 1x8x128x32 + [2, 0, 1, 3] -> 128x1x8x32 // Compact, with supported permutation vector: // 1x8x128x32 + [0, 2, 1, 3] ("NWHC") -> 1x128x8x32 + Shape perm = attr->perm; // Don't bother with rearrangement if transpose type is already supported - if ( _constraints->SupportsTranspose(OpType::MemoryCopy, ofmConn->transpose) ) + auto transposeType = TransposeTypeFromShape(perm); + if ( _constraints->SupportsTranspose(OpType::Transpose, transposeType) != TransposeSupport::None ) { return returnOp; } Shape ifmShape = ifmConn->shape; Shape ofmShape = ofmConn->shape; - Shape perm = attr->perm; - assert(perm); int ofmDim = perm.Size() - 1; for ( auto onesMask = ofmShape.EqualMask(ofmShape.WithOnes()); onesMask; onesMask >>= 1 ) { @@ -1496,7 +1577,6 @@ Operation *GraphIrOptimiser::RearrangeTranspose(Graph *const graph, Operation *c ofmDim--; } - ofmConn->transpose = 
TransposeTypeFromShape(perm); attr->perm = perm; ifmConn->shape = ifmShape; ofmConn->shape = ofmShape; @@ -1894,10 +1974,7 @@ Operation *GraphIrOptimiser::MoveSplitSliceToConsumer(Graph *const, Operation *c auto *ofm = ofmConn->tensor.get(); // TODO: MLBEDSW-9072: Add check that moving split to consumer is valid - - // We can only move to consumer if there is no transpose on the op that we will remove, - // otherwise we will lose that transposition. - if ( ofm->Readers().size() == 1 && IsNone(ofmConn->transpose) ) + if ( ofm->Readers().size() == 1 ) { auto cons = ofm->Readers().front(); auto consOfmConn = cons->Output(TensorUsage::OFM); @@ -1918,10 +1995,15 @@ Operation *GraphIrOptimiser::MoveSplitSliceToConsumer(Graph *const, Operation *c ifmShapeEqual = consIfm1Conn->shape == ofmConn->shape; } + TransposeType consumerTranspose = TransposeType::None; + if ( cons->Type() == OpType::Transpose ) + { + consumerTranspose = TransposeTypeFromShape(cons->Attribute()->perm); + } + // We can only move to consumer if there is no transpose on the op that we move to, // otherwise the IFM shape may change and transposition will be wrong. - if ( !IsReshape(cons->Type()) && ofmConn->shape == Shape::PadAxes(ofm->StorageShape(), 4, 1) && - IsNone(consOfmConn->transpose) && ifmShapeEqual ) + if ( !IsReshape(cons->Type()) && ofmConn->shape == Shape::PadAxes(ofm->StorageShape(), 4, 1) && IsNone(consumerTranspose) && ifmShapeEqual ) { // Split/Slice can be performed by tensor consumer MoveToConsumer(operation, cons.get()); diff --git a/ethosu/regor/compiler/graphir_optimiser.hpp b/ethosu/regor/compiler/graphir_optimiser.hpp index cd3720ffe72184ed28b92729dab15bdcfc974c2c..86dec4a3c07a6765d76e7954d3ceabac6a23a9c4 100644 --- a/ethosu/regor/compiler/graphir_optimiser.hpp +++ b/ethosu/regor/compiler/graphir_optimiser.hpp @@ -67,6 +67,7 @@ private: Operation *RewriteDepthwise(Graph *const graph, Operation *const operation); Operation *RewriteTransposeConvOFMPadding(Graph *const graph, Operation *const operation); Operation *OptimiseElementwise(Graph *const graph, Operation *const operation); + Operation *MergeTransposes(Graph *const graph, Operation *const operation); Operation *RearrangeTranspose(Graph *const graph, Operation *const operation); Operation *ReshapeReverse(Graph *const graph, Operation *const operation); void MoveToConsumer(const Operation *const operation, Operation *const cons); @@ -144,6 +145,7 @@ private: &GraphIrOptimiser::RewriteDepthwise, &GraphIrOptimiser::RewriteTransposeConvOFMPadding, &GraphIrOptimiser::OptimiseElementwise, + &GraphIrOptimiser::MergeTransposes, &GraphIrOptimiser::RearrangeTranspose, &GraphIrOptimiser::ReshapeReverse, &GraphIrOptimiser::UnrollConv diff --git a/ethosu/regor/compiler/operation.cpp b/ethosu/regor/compiler/operation.cpp index 7ea54f59637b51c5c5e6914b2108a5544e2217f7..18937b1a08cfebf63bde0f53b3f79a0e5640adf9 100644 --- a/ethosu/regor/compiler/operation.cpp +++ b/ethosu/regor/compiler/operation.cpp @@ -73,8 +73,7 @@ void Operation::CopyInput(TensorUsage usage, const TensorConnection &tensorConne ConnectInput(usage, tensorConnection.tensor) .Set(tensorConnection.shape) .Set(tensorConnection.slice) - .Set(tensorConnection.quantization) - .Set(tensorConnection.transpose); + .Set(tensorConnection.quantization); } TensorConnection &Operation::ConnectInput(TensorUsage usage, const std::shared_ptr &tensor) @@ -107,8 +106,7 @@ void Operation::CopyOutput(TensorUsage usage, const TensorConnection &tensorConn ConnectOutput(usage, tensorConnection.tensor) 
.Set(tensorConnection.shape) .Set(tensorConnection.slice) - .Set(tensorConnection.quantization) - .Set(tensorConnection.transpose); + .Set(tensorConnection.quantization); } TensorConnection &Operation::ConnectOutput(TensorUsage usage, const std::shared_ptr &tensor) diff --git a/ethosu/regor/compiler/operation.hpp b/ethosu/regor/compiler/operation.hpp index 9922ed5610bbf55e2eecb778c1719cc9792adf26..121c3053f83c9d50f6a429c2daec12572006cb47 100644 --- a/ethosu/regor/compiler/operation.hpp +++ b/ethosu/regor/compiler/operation.hpp @@ -73,7 +73,6 @@ struct TensorConnection // Reading: Split, SplitV, Unpack, Slice, and StridedSlice TensorSlice slice; Quantization quantization; - TransposeType transpose = TransposeType::None; ReverseType reverse = ReverseType::None; TensorConnection &Set(const Shape &s) @@ -91,11 +90,6 @@ struct TensorConnection quantization = q; return *this; } - TensorConnection &Set(const TransposeType &t) - { - transpose = t; - return *this; - } TensorConnection &Set(const ReverseType &r) { reverse = r; diff --git a/ethosu/regor/compiler/operation_util.hpp b/ethosu/regor/compiler/operation_util.hpp index cf4f763a241a7e72fb00dc29c224b9aac405c8a6..1bf62b5a531fb2d451966ed46902b69e5735e05a 100644 --- a/ethosu/regor/compiler/operation_util.hpp +++ b/ethosu/regor/compiler/operation_util.hpp @@ -23,6 +23,7 @@ #include "common/buffer_view.hpp" #include "operation.hpp" #include "quantization.hpp" +#include "shape_util.hpp" #include "tensor.hpp" #include @@ -260,25 +261,6 @@ inline Operation *CreateRescaleAdd(const std::shared_ptr &ifm, const std return op; } -// Convert a permutation shape (up to 8 elements) to a TransposeType -// For example: -// [0, 1, 2, 3] -> 0x0123 ("NHWC") -// [0, 1, 2] -> 0x0123 ("NHWC") -// [0, 1] -> 0x0123 ("NHWC") -// [0] -> 0x0123 ("NHWC") -// [0, 2, 1, 3] -> 0x0213 ("NWHC") -// [1, 0, 2] -> 0x0213 ("NWHC") -inline TransposeType TransposeTypeFromShape(const Shape &perm) -{ - const int n = perm.Size(); - // We can only handle permutation vectors up 8 elements - if ( n > 8 ) throw std::invalid_argument("Permutation shape has more than 8 elements"); - uint32_t mask = perm.ToMask(); - uint32_t offset = 0x76543210 & ~(0xFFFFFFFF >> (4 * (8 - n))); - uint32_t mask8D = mask + offset; - return TransposeType(mask8D); -} - inline TransposeType CalculateTransposeType(const Operation &operation) { const auto *paramsConn = operation.Input(TensorUsage::Params); @@ -298,34 +280,6 @@ inline bool IsScalingValidAndEqual(const TensorConnection &a, const TensorConnec a.quantization.zeroPoints == b.quantization.zeroPoints); } -// Reshape for example (A, B, N, H, W, C) + (3, 2, 1) -> (A*B*N, H*W, C) -inline Shape ReshapeTo3D(const Shape &shape, const Shape &axes, int minAxis = 1) -{ - assert(axes.Size() == 3); - assert(axes[0] + axes[1] + axes[2] == shape.Size()); - int h = std::max(minAxis, shape.AxisProduct(0, axes[0])); - int w = std::max(minAxis, shape.AxisProduct(axes[0], axes[0] + axes[1])); - int c = std::max(minAxis, shape.AxisProduct(axes[0] + axes[1], axes[0] + axes[1] + axes[2])); - return Shape(h, w, c); -} - -// Reshape for example (B, N, H, W, C) + W -> (B*N*H, W, C) -inline Shape ReshapeTo3DAroundAxis(const Shape &shape, int axis, int minAxis = 1) -{ - assert(axis >= 0); - assert(axis < shape.Size()); - int outer = axis; - int inner = shape.Size() - axis - 1; - return ReshapeTo3D(shape, {outer, 1, inner}, minAxis); -} - -// Reshape (B, N, H, W, C) -> (B, N*H*W, C) -inline Shape ReshapeTo3DAroundEdges(const Shape &shape, int minAxis = 1) -{ - assert(shape.Size() > 
1); - return ReshapeTo3D(shape, {1, shape.Size() - 2, 1}, minAxis); -} - #undef FOR_ALL_INT_TYPES } // namespace regor diff --git a/ethosu/regor/compiler/optimiser_utils.cpp b/ethosu/regor/compiler/optimiser_utils.cpp index e48fb1bc59fa5b69ee3abfbafef78dd7a01635d2..f52fa0e1f57ae19eb398b6e43a5a65f8ada14e2b 100644 --- a/ethosu/regor/compiler/optimiser_utils.cpp +++ b/ethosu/regor/compiler/optimiser_utils.cpp @@ -117,19 +117,19 @@ void ReplaceConsumerInput(const Operation *const exemptOperation, std::vectorInputs().pairs() ) + for ( const auto &consInput : consumer->Inputs().pairs() ) { - if ( consInput.second.tensor.get() == tensorToReplace && cons != exemptOperation ) + if ( consInput.second.tensor.get() == tensorToReplace ) { // Do not want to replace the shape. Only the tensor and add writers. // As ConnectInput but do not replace shape. - newTensor->AddReader(cons->shared_from_this()); - auto *consInputConnection = cons->Input(consInput.first); + newTensor->AddReader(consumer); + auto *consInputConnection = consumer->Input(consInput.first); if ( consInputConnection->tensor != newTensor ) { - consInputConnection->tensor->RemoveReader(cons->shared_from_this()); + consInputConnection->tensor->RemoveReader(consumer); consInputConnection->tensor = newTensor; } } diff --git a/ethosu/regor/compiler/raw_writer.hpp b/ethosu/regor/compiler/raw_writer.hpp index b23eae37ebe84523cd31885f11da9895c9c56f6a..2d6979afbaa376b0c50cb0af34d443d996e2b893 100644 --- a/ethosu/regor/compiler/raw_writer.hpp +++ b/ethosu/regor/compiler/raw_writer.hpp @@ -18,6 +18,7 @@ #pragma once +#include "architecture/architecture.hpp" #include "compiler/graph.hpp" #include "compiler/tensor.hpp" diff --git a/ethosu/regor/compiler/scheduler_decompose.cpp b/ethosu/regor/compiler/scheduler_decompose.cpp index 8adcd1c20cb09136975c1a2513df4126ec736b2f..7324e2c6abe1b8250f828383f91612d75b574f85 100644 --- a/ethosu/regor/compiler/scheduler_decompose.cpp +++ b/ethosu/regor/compiler/scheduler_decompose.cpp @@ -20,7 +20,8 @@ #include "common/logging.hpp" -#include "operation_util.hpp" +#include "architecture/architecture_constraints.hpp" +#include "shape_util.hpp" #include #include diff --git a/ethosu/regor/compiler/scheduler_decompose.hpp b/ethosu/regor/compiler/scheduler_decompose.hpp index b11c6952cfa2b02947521399c5e30746798d8346..fee681814673b105a8b00aa5c7e1cb18dce200c8 100644 --- a/ethosu/regor/compiler/scheduler_decompose.hpp +++ b/ethosu/regor/compiler/scheduler_decompose.hpp @@ -19,7 +19,6 @@ #pragma once #include "graph.hpp" -#include "operation.hpp" #include "scheduler_operation.hpp" #include diff --git a/ethosu/regor/compiler/scheduler_operation.hpp b/ethosu/regor/compiler/scheduler_operation.hpp index e64537bc74e373e1194d7519d39a4bf471281c54..33b9731c72708f32a914fe4ed0c0baffb8f14acf 100644 --- a/ethosu/regor/compiler/scheduler_operation.hpp +++ b/ethosu/regor/compiler/scheduler_operation.hpp @@ -20,6 +20,7 @@ #include "common/common.hpp" +#include "architecture/architecture.hpp" #include "common/ordered_map.hpp" #include "kernel.hpp" #include "operation.hpp" diff --git a/ethosu/regor/compiler/scheduler_packing.cpp b/ethosu/regor/compiler/scheduler_packing.cpp index 32b1e82d293271cfd9cf0e2751ee057397aa59f7..631f6040476ecf26415b0c9c9877de8046f0f1aa 100644 --- a/ethosu/regor/compiler/scheduler_packing.cpp +++ b/ethosu/regor/compiler/scheduler_packing.cpp @@ -26,6 +26,7 @@ #include "operation.hpp" #include "scheduler_decompose.hpp" #include "scheduler_operation.hpp" +#include "shape_util.hpp" #include "tensor.hpp" #include @@ 
-416,7 +417,6 @@ void SchedulerPacking::InitSchedulerConnection(
     schedConn->slice = {Shape::PadAxes(conn.slice.offset, 4, 0), Shape::PadAxes(conn.slice.shape, 4, 1)};
     schedConn->shape = Shape::PadAxes(conn.shape, 4, 1);
     schedConn->quantization = conn.quantization;
-    schedConn->transpose = conn.transpose;
     schedConn->reverse = conn.reverse;
     schedConn->resamplingMode = ArchResampling::None;
 }
@@ -479,6 +479,7 @@ std::unique_ptr<SchedulerOperation> SchedulerPacking::MakeSchedulerOperation(Ope
             }
             SchedulerConnection *schedConn = IsOFM(item.first) ? schedOp->AddOutput(item.first) : schedOp->AddInput(item.first);
             InitSchedulerConnection(schedConn, schedTensor, item.second);
+            schedConn->transpose = TransposeType::None;
         }
     }
 
@@ -490,6 +491,12 @@ std::unique_ptr<SchedulerOperation> SchedulerPacking::MakeSchedulerOperation(Ope
         assert(paddedAxes >= 0);
        attr->axis += paddedAxes;
    }
+    // Update OFM transpose mask if operator has the attribute
+    else if ( schedOp->HasAttribute<transpose_attr_t>() )
+    {
+        auto attr = schedOp->Attribute<transpose_attr_t>();
+        schedOp->OFM()->transpose = TransposeTypeFromShape(attr->perm);
+    }
 
     // Examine elementwise and set a primary path for cascading.
     if ( IsBinaryElementwise(op->Type()) )
diff --git a/ethosu/regor/compiler/shape_util.hpp b/ethosu/regor/compiler/shape_util.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..4a695147196fea98aa444fe6199501070c329dc3
--- /dev/null
+++ b/ethosu/regor/compiler/shape_util.hpp
@@ -0,0 +1,74 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the License); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an AS IS BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License. 
+//
+
+#pragma once
+
+#include "common/shape.hpp"
+#include "common/transpose_type.hpp"
+
+namespace regor
+{
+
+// Convert a permutation shape (up to 8 elements) to a TransposeType
+// For example:
+// [0, 1, 2, 3] -> 0x0123 ("NHWC")
+// [0, 1, 2] -> 0x0123 ("NHWC")
+// [0, 1] -> 0x0123 ("NHWC")
+// [0] -> 0x0123 ("NHWC")
+// [0, 2, 1, 3] -> 0x0213 ("NWHC")
+// [1, 0, 2] -> 0x0213 ("NWHC")
+inline TransposeType TransposeTypeFromShape(const Shape &perm)
+{
+    const int n = perm.Size();
+    // We can only handle permutation vectors up to 8 elements
+    if ( n > 8 ) throw std::invalid_argument("Permutation shape has more than 8 elements");
+    uint32_t mask = perm.ToMask();
+    uint32_t offset = 0x76543210 & ~(0xFFFFFFFF >> (4 * (8 - n)));
+    uint32_t mask8D = mask + offset;
+    return TransposeType(mask8D);
+}
+
+// Reshape for example (A, B, N, H, W, C) + (3, 2, 1) -> (A*B*N, H*W, C)
+inline Shape ReshapeTo3D(const Shape &shape, const Shape &axes, int minAxis = 1)
+{
+    assert(axes.Size() == 3);
+    assert(axes[0] + axes[1] + axes[2] == shape.Size());
+    int h = std::max(minAxis, shape.AxisProduct(0, axes[0]));
+    int w = std::max(minAxis, shape.AxisProduct(axes[0], axes[0] + axes[1]));
+    int c = std::max(minAxis, shape.AxisProduct(axes[0] + axes[1], axes[0] + axes[1] + axes[2]));
+    return Shape(h, w, c);
+}
+
+// Reshape for example (B, N, H, W, C) + W -> (B*N*H, W, C)
+inline Shape ReshapeTo3DAroundAxis(const Shape &shape, int axis, int minAxis = 1)
+{
+    assert(axis >= 0);
+    assert(axis < shape.Size());
+    int outer = axis;
+    int inner = shape.Size() - axis - 1;
+    return ReshapeTo3D(shape, {outer, 1, inner}, minAxis);
+}
+
+// Reshape (B, N, H, W, C) -> (B, N*H*W, C)
+inline Shape ReshapeTo3DAroundEdges(const Shape &shape, int minAxis = 1)
+{
+    assert(shape.Size() > 1);
+    return ReshapeTo3D(shape, {1, shape.Size() - 2, 1}, minAxis);
+}
+
+} // namespace regor
diff --git a/ethosu/regor/compiler/tensor.hpp b/ethosu/regor/compiler/tensor.hpp
index f83b00096c99cdd977058f0a70600f48fbd20e80..805697a3d939a6a3d9ac0d5eb967023aee50f964 100644
--- a/ethosu/regor/compiler/tensor.hpp
+++ b/ethosu/regor/compiler/tensor.hpp
@@ -20,7 +20,6 @@
 
 #include "common/common.hpp"
 
-#include "architecture/architecture.hpp"
 #include "common/buffer_view.hpp"
 #include "common/data_type.hpp"
 #include "common/shape.hpp"
@@ -88,6 +87,8 @@ public:
     void RemoveReaders();
     void RemoveWriters();
 
+    bool IsSinglePath() const { return _readers.size() == 1 && _writers.size() == 1; }
+
     std::unique_ptr<Tensor> Clone() const;
     std::string ToString() const;
 };
diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp
index ee079d83773425e4d0cadc9ae424376d28bad6a1..9f646cda6a68464ec80891ae6285835c721d729f 100644
--- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp
+++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp
@@ -1555,9 +1555,9 @@ Operation *TFLiteGraphOptimiser::RewriteFullyConnectDynamic(Graph *const, Operat
 
     auto matMulOp = std::make_shared<Operation>(OpType::MatMul);
     matMulOp->SetRounding(ifm->tensor->Type() == DataType::Int16 ? 
RoundMode::NATURAL : RoundMode::DBL);
-    matMulOp->ConnectInput(TensorUsage::IFM0, ifm->tensor).Set(ifmShape).Set(ifm->quantization).Set(ifm->slice).Set(ifm->transpose);
-    matMulOp->ConnectInput(TensorUsage::IFM1, ifm2Tensor).Set(ifm2Reshaped).Set(ifm2->quantization).Set(ifm2->slice).Set(ifm2->transpose);
-    matMulOp->ConnectOutput(TensorUsage::OFM, ofm->tensor).Set(ofmShape).Set(ofm->quantization).Set(ofm->slice).Set(ofm->transpose);
+    matMulOp->ConnectInput(TensorUsage::IFM0, ifm->tensor).Set(ifmShape).Set(ifm->quantization).Set(ifm->slice);
+    matMulOp->ConnectInput(TensorUsage::IFM1, ifm2Tensor).Set(ifm2Reshaped).Set(ifm2->quantization).Set(ifm2->slice);
+    matMulOp->ConnectOutput(TensorUsage::OFM, ofm->tensor).Set(ofmShape).Set(ofm->quantization).Set(ofm->slice);
 
     RecordOptimisation(operation, matMulOp.get());
     returnOp = matMulOp.get();
diff --git a/ethosu/regor/test/test_graphir_optimiser.cpp b/ethosu/regor/test/test_graphir_optimiser.cpp
index e5893162cc7205a46848112db84eb65ccca78a4d..3812bb4ba5a403b5f05e2715d05d0c5a6f16db11 100644
--- a/ethosu/regor/test/test_graphir_optimiser.cpp
+++ b/ethosu/regor/test/test_graphir_optimiser.cpp
@@ -175,3 +175,87 @@ TEST_CASE("test_graphir_optimiser - ReduceSum")
         REQUIRE(scheduleOps[0]->IFM(0)->quantization.zeroPoints[0] == 0);
     }
 }
+
+TEST_CASE("test_graphir_optimiser - transpose removal")
+{
+    // Create arch
+    auto arch = CreateArchDefault();
+    std::string err = "noerror";
+    arch->CheckConfiguration(err);
+    REQUIRE(err == "noerror");
+
+    std::vector<std::shared_ptr<Operation>> ops;
+    auto cadd = CreateTensor("CADD", Shape(1, 1, 1, 1), DataType::Int8, 1);
+    auto input = CreateTensor("INPUT", Shape(1, 10, 5, 4), DataType::Int8);
+    auto ofm1 = CreateTensor("OFM", Shape(1, 10, 5, 4), DataType::Int8);
+    auto ofm2 = CreateTensor("OFM", Shape(1, 10, 5, 4), DataType::Int8);
+    auto output = CreateTensor("OUTPUT", Shape(1, 10, 5, 4), DataType::Int8);
+
+    // Add->Transpose(none)->Add
+    ops.push_back(CreateOperation(OpType::Add, TensorUsage::IFM, input, TensorUsage::IFM1, cadd, TensorUsage::OFM, ofm1));
+
+    ops.push_back(CreateOperation(OpType::Transpose, TensorUsage::IFM, ofm1, TensorUsage::OFM, ofm2));
+    transpose_attr_t *attr = ops.back()->Attribute<transpose_attr_t>();
+    attr->perm = Shape(0, 1, 2, 3);
+
+    ops.push_back(CreateOperation(OpType::Add, TensorUsage::IFM, ofm2, TensorUsage::IFM1, cadd, TensorUsage::OFM, output));
+
+    auto graph = CreateGraph(ops);
+
+    GraphOptimiserOptions options;
+    auto optimiser = GraphOptimiser::MakeGraphOptimiser(graph->Notation(), arch->Constraints(), options, nullptr);
+
+    optimiser->Process(graph.get());
+
+    std::vector<Operation *> allOps;
+    graph->GetAllOperations(allOps);
+    REQUIRE(allOps.size() == 2);
+    REQUIRE(allOps.front()->Type() == OpType::Add);
+    REQUIRE(allOps.back()->Type() == OpType::Add);
+    REQUIRE(allOps.front()->Output(TensorUsage::OFM)->tensor == allOps.back()->Input(TensorUsage::IFM)->tensor);
+}
+
+TEST_CASE("test_graphir_optimiser - transpose merge")
+{
+    // Create arch
+    auto arch = CreateArchDefault();
+    std::string err = "noerror";
+    arch->CheckConfiguration(err);
+    REQUIRE(err == "noerror");
+
+    std::vector<std::shared_ptr<Operation>> ops;
+    auto cadd = CreateTensor("CADD", Shape(1, 1, 1, 1), DataType::Int8, 1);
+    auto input = CreateTensor("INPUT", Shape(1, 10, 4, 5), DataType::Int8);
+    auto ofm1 = CreateTensor("OFM", Shape(1, 10, 4, 5), DataType::Int8);
+    auto ofm2 = CreateTensor("OFM", Shape(1, 10, 5, 4), DataType::Int8);
+    auto ofm3 = CreateTensor("OFM", Shape(1, 10, 4, 5), DataType::Int8);
+    auto output = CreateTensor("OUTPUT", Shape(1, 10, 4, 5), DataType::Int8);
+
+    // 
Add->Transpose(there)->Transpose(back)->Add
+    ops.push_back(CreateOperation(OpType::Add, TensorUsage::IFM, input, TensorUsage::IFM1, cadd, TensorUsage::OFM, ofm1));
+
+    ops.push_back(CreateOperation(OpType::Transpose, TensorUsage::IFM, ofm1, TensorUsage::OFM, ofm2));
+    transpose_attr_t *attr = ops.back()->Attribute<transpose_attr_t>();
+    attr->perm = Shape(0, 1, 3, 2);
+
+    ops.push_back(CreateOperation(OpType::Transpose, TensorUsage::IFM, ofm2, TensorUsage::OFM, ofm3));
+    attr = ops.back()->Attribute<transpose_attr_t>();
+    attr->perm = Shape(0, 1, 3, 2);
+
+    ops.push_back(CreateOperation(OpType::Add, TensorUsage::IFM, ofm3, TensorUsage::IFM1, cadd, TensorUsage::OFM, output));
+
+    auto graph = CreateGraph(ops);
+
+    GraphOptimiserOptions options;
+    auto optimiser = GraphOptimiser::MakeGraphOptimiser(graph->Notation(), arch->Constraints(), options, nullptr);
+
+    optimiser->Process(graph.get());
+
+    // Result Add->Add
+    std::vector<Operation *> allOps;
+    graph->GetAllOperations(allOps);
+    REQUIRE(allOps.size() == 2);
+    REQUIRE(allOps.front()->Type() == OpType::Add);
+    REQUIRE(allOps.back()->Type() == OpType::Add);
+    REQUIRE(allOps.front()->Output(TensorUsage::OFM)->tensor == allOps.back()->Input(TensorUsage::IFM)->tensor);
+}
diff --git a/ethosu/regor/tflite/tflite_writer.cpp b/ethosu/regor/tflite/tflite_writer.cpp
index 726709165c206bfe593c43af163fe859fcaa1bb6..79005e23113d977261bb129e18cdc65c4f19a86c 100644
--- a/ethosu/regor/tflite/tflite_writer.cpp
+++ b/ethosu/regor/tflite/tflite_writer.cpp
@@ -20,6 +20,7 @@
 
 #include "common/logging.hpp"
 
+#include "architecture/architecture.hpp"
 #include "flatbuffer_utils.hpp"
 #include "tflite_mapping.hpp"
 
diff --git a/ethosu/regor/tflite/tflite_writer.hpp b/ethosu/regor/tflite/tflite_writer.hpp
index 5378a341a3927ec6fbe38a626b63e1dc4bcae219..6d74dcba5c0cf1494a17cae303036d0087d57cc7 100644
--- a/ethosu/regor/tflite/tflite_writer.hpp
+++ b/ethosu/regor/tflite/tflite_writer.hpp
@@ -19,6 +19,7 @@
 
 #pragma once
 
+#include "architecture/architecture.hpp"
 #include "compiler/graph.hpp"
 #include "compiler/op_type.hpp"
 #include "compiler/operation.hpp"
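
A minimal, self-contained sketch (not part of the patch) of the axis-grouping rule behind the new ReshapeTo3D helper in shape_util.hpp: std::vector<int> stands in for regor's Shape and std::accumulate for Shape::AxisProduct, so the name ReshapeTo3DSketch and its signature are illustrative assumptions only.

// Group a rank-N shape into three buckets of sizes axes[0], axes[1], axes[2]
// (which must sum to N) and multiply the dimensions inside each bucket.
#include <algorithm>
#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

static std::vector<int> ReshapeTo3DSketch(const std::vector<int> &shape, const std::vector<int> &axes, int minAxis = 1)
{
    assert(axes.size() == 3);
    assert(axes[0] + axes[1] + axes[2] == static_cast<int>(shape.size()));
    auto product = [&shape](int from, int to) {
        // Product of the dimensions in [from, to); an empty range yields 1.
        return std::accumulate(shape.begin() + from, shape.begin() + to, 1, std::multiplies<int>());
    };
    const int h = std::max(minAxis, product(0, axes[0]));
    const int w = std::max(minAxis, product(axes[0], axes[0] + axes[1]));
    const int c = std::max(minAxis, product(axes[0] + axes[1], axes[0] + axes[1] + axes[2]));
    return {h, w, c};
}

int main()
{
    // (A, B, N, H, W, C) + (3, 2, 1) -> (A*B*N, H*W, C), as in the header comment.
    assert((ReshapeTo3DSketch({2, 3, 4, 5, 6, 7}, {3, 2, 1}) == std::vector<int>{2 * 3 * 4, 5 * 6, 7}));
    // ReshapeTo3DAroundAxis(shape, axis) is the special case {axis, 1, rank - axis - 1}.
    assert((ReshapeTo3DSketch({1, 10, 5, 4}, {2, 1, 1}) == std::vector<int>{10, 5, 4}));
    return 0;
}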