diff --git a/ethosu/regor/tflite/tflite_reader.cpp b/ethosu/regor/tflite/tflite_reader.cpp index 24ec3cb0167d5e57dc23f3efecd052bc7aa4bd07..2491d018daa52d2833f6d0777df3391e28a4ef7b 100644 --- a/ethosu/regor/tflite/tflite_reader.cpp +++ b/ethosu/regor/tflite/tflite_reader.cpp @@ -553,6 +553,14 @@ void TfLiteReader::ParseOperatorOptions(const std::shared_ptr &operat auto biasTens = std::make_shared(weight_tensor->Name() + "_bias", biasType, Shape(1, 1, 1, elems), buf); operation->ConnectInput(TensorUsage::Scales, biasTens); } + if ( options->keep_num_dims() ) + { + auto &ofmShape = operation->Output(TensorUsage::OFM)->shape; + assert(ofmShape[1] <= weight_tensor->StorageShape()[0]); + ofmShape = Shape(ofmShape[0], ofmShape[1]); + auto &ifmShape = operation->Input(TensorUsage::IFM)->shape; + ifmShape = Shape(ofmShape[0], ifmShape[-1]); + } } break;