diff --git a/ethosu/regor/tflite/tflite_reader.cpp b/ethosu/regor/tflite/tflite_reader.cpp index 50d963243b7025f25d156ac86be388a0503a237d..e5ad5f345d597c16e5bf2d2046f9d1102a8719ea 100644 --- a/ethosu/regor/tflite/tflite_reader.cpp +++ b/ethosu/regor/tflite/tflite_reader.cpp @@ -275,6 +275,23 @@ void TfLiteReader::LoadGraphs(const uint8_t *input, const tflite::Model *model, for ( const int tensor_index : *tflite_outputs ) { const auto &ofm = tensors.at(tensor_index); + if ( !ofm->StorageShape() ) + { + // Try to figure out the OFM shape if the OFM shape is unknown + if ( IsUnaryElementwise(op_type) || op_type == OpType::Quantize ) + { + auto ifm = operation->IFM(0); + assert(ifm); + ofm->SetStorageShape(ifm->StorageShape()); + } + else if ( IsBinaryElementwise(op_type) ) + { + auto ifm0 = operation->IFM(0); + auto ifm1 = operation->IFM(1); + assert(ifm0 && ifm1); + ofm->SetStorageShape(Shape::Max(ifm0->StorageShape(), ifm1->StorageShape())); + } + } shapelessTensors = shapelessTensors || !ofm->StorageShape(); assert(tensorQuantization.count(ofm->Uid()) > 0); operation->ConnectOutput(MakeTensorUsage(TensorUsage::OFM, ofm_count++), ofm).Set(tensorQuantization[ofm->Uid()]);