diff --git a/ethosu/regor/architecture/ethosu55/ethos_u55.cpp b/ethosu/regor/architecture/ethosu55/ethos_u55.cpp index d03ff1a28e882c0192e500df6458b6cf8169f82f..3fea9e9f467f6e4dd238ac2d014a977c33dc9b3b 100644 --- a/ethosu/regor/architecture/ethosu55/ethos_u55.cpp +++ b/ethosu/regor/architecture/ethosu55/ethos_u55.cpp @@ -180,6 +180,20 @@ void ArchEthosU55::ApplyConfig(const AcceleratorConfig *cfg) _rcsGenerator = std::make_unique(this); } +static Shape MatMulDependencyFit(const Shape &shape, int minSize, const Shape &blockLimit) +{ + // Attempt to fit multiple blocks in W/H to reduce block + // dependency stalls + int axis = (shape.Height() > blockLimit.Height()) ? -3 : -2; + if ( shape[axis] <= blockLimit[axis] ) + { + for ( int divider = 3; divider > 1; divider-- ) + { + if ( shape[axis] >= (minSize * divider) ) return shape.With(axis, DivRoundUp(shape[axis], divider)); + } + } + return shape; +} std::unique_ptr ArchEthosU55::GetOpConfig(OpType opType, const ArchitectureConfigQuery &query) { @@ -189,17 +203,19 @@ std::unique_ptr ArchEthosU55::GetOpConfig(OpType opType, c ArchitectureConfigQuery tmpQuery = query; Kernel unitKernel = Kernel::UnitKernel(); int batches = query.ofmShape.Height(); + // Block configuration for the Elementwise Mul tmpQuery.kernel = &unitKernel; tmpQuery.ifmBits = query.ifmBits; tmpQuery.ifmShape[1] = Shape(1, batches, 1, query.ifmShape[1].Depth()); - tmpQuery.ofmShape = query.ifmShape[0]; + tmpQuery.ifmShape[0] = MatMulDependencyFit(query.ifmShape[0], 4, _ofmBlockMax); + tmpQuery.ofmShape = tmpQuery.ifmShape[0]; tmpQuery.ofmFormat = TensorFormat::NHWC; tmpQuery.ofmBits = 32; tmpQuery.transpose = TransposeType::None; auto mulConfig = FindBlockConfig(OpType::Mul, tmpQuery); // Block configuration for the Reduced Sum - tmpQuery.ofmShape = Shape(1, batches, query.ifmShape[0].Width(), 1); + tmpQuery.ofmShape = MatMulDependencyFit(Shape(1, batches, query.ifmShape[0].Width(), 1), 4, _ofmBlockMax); tmpQuery.ofmBits = query.ofmBits; tmpQuery.ofmFormat = query.ofmFormat; auto reduceConfig = FindBlockConfig(OpType::ReduceSum, tmpQuery);