diff --git a/ethosu/regor/compiler/graphir_optimiser.cpp b/ethosu/regor/compiler/graphir_optimiser.cpp index ab6337653b836c175f31d0436af9706c57114264..893efbfb3a0274e125cff32fd7937010c5d221a3 100644 --- a/ethosu/regor/compiler/graphir_optimiser.cpp +++ b/ethosu/regor/compiler/graphir_optimiser.cpp @@ -1808,26 +1808,24 @@ Operation *GraphIrOptimiser::RewriteDepthwise(Graph *const graph, Operation *con { const auto ifm = operation->Input(TensorUsage::IFM0); const auto ofm = operation->Output(TensorUsage::OFM); - const auto weights = operation->Input(TensorUsage::Weights); - const auto shape = weights->tensor->StorageShape(); - const auto &axisOrder = weights->tensor->AxisOrder(); const auto multiplier = operation->Kernel()->DepthMultiplier(); if ( ifm && (ifm->shape.Depth() == 1) && (multiplier != 1) && ofm && (ofm->shape.Depth() == multiplier) ) { auto newOp = std::make_shared(OpType::Conv2D); - RoundMode ofmRound = ofm->rounding; - auto kernel = std::make_unique(operation->Kernel()->Size(), operation->Kernel()->Stride(), - operation->Kernel()->Dilation(), 1, operation->Kernel()->Padding()); + auto kernel = std::make_unique(operation->Kernel()->WithDepthMultiplier(1)); newOp->SetKernel(std::move(kernel)); - if ( axisOrder == AxisOrder::HWCM ) + const auto weights = operation->Input(TensorUsage::Weights); + if ( weights->tensor->AxisOrder() == AxisOrder::HWCM ) { + const auto &shape = weights->tensor->StorageShape(); weights->tensor->Reshape(Shape(1, shape[0], shape[1], shape[3])); weights->tensor->SetAxisOrder(AxisOrder::IHWO); + weights->shape = weights->tensor->StorageShape(); } ReplaceOperation(operation, newOp.get()); - newOp->Output(TensorUsage::OFM)->Set(ofmRound); + newOp->Output(TensorUsage::OFM)->Set(ofm->rounding); returnOp = newOp.get(); RecordOptimisation(operation, returnOp); }