diff --git a/ethosu/regor/compiler/tflite_graph_optimiser.cpp b/ethosu/regor/compiler/tflite_graph_optimiser.cpp index de3d63649248d7fb2ad8056a2acf6a9a3d8b2504..b0758ae20cd740c71198c1f02418375ff77fd907 100644 --- a/ethosu/regor/compiler/tflite_graph_optimiser.cpp +++ b/ethosu/regor/compiler/tflite_graph_optimiser.cpp @@ -1456,7 +1456,6 @@ Operation *TFLiteGraphOptimiser::RewriteFullyConnectDynamic(Graph *const, Operat auto ofmShape = Shape::PadAxes(ofm->shape, 4, 1); auto ifmShape = Shape::PadAxes(ifm->shape, 4, 1); - assert(ifm2->tensor->AxisOrder() == AxisOrder::OHWI); assert(ifm2->shape.Size() == 4 && "FullyConnected with non-4D weights"); assert(ifm2->shape.ElementsWH() == 1 && "FullyConnected with non-unit W*H weight-shape"); diff --git a/ethosu/regor/test/test_tflite_supported_operators.cpp b/ethosu/regor/test/test_tflite_supported_operators.cpp index 7aeff463986548427ace5f5db211e17f229fcac2..056abaca58c39528412a1bb6f5cd84ddcd9f0580 100644 --- a/ethosu/regor/test/test_tflite_supported_operators.cpp +++ b/ethosu/regor/test/test_tflite_supported_operators.cpp @@ -124,7 +124,9 @@ TEST_CASE("Supported operators Common") weights->SetAxisOrder(AxisOrder::OHWI); op->ConnectInput(TensorUsage::Weights, weights).Set(Quantization::Unit()); REQUIRE(supportedOps->Check(op.get()) == true); - op->Input(TensorUsage::Weights)->tensor->Reshape(Shape(2, 2, 1, 2)); + // reshape and reconnect tensor + weights->Reshape(Shape(2, 2, 1, 2)); + op->ConnectInput(TensorUsage::Weights, weights).Set(Quantization::Unit()); REQUIRE(supportedOps->Check(op.get()) == false); op->Disconnect(); } diff --git a/ethosu/regor/tflite/tflite_supported_operators.cpp b/ethosu/regor/tflite/tflite_supported_operators.cpp index 4890bf7aadad5350e8df09f8e468fd1bf6ffc4dd..079d9fda96dbae4e6d6067198f43802396055fba 100644 --- a/ethosu/regor/tflite/tflite_supported_operators.cpp +++ b/ethosu/regor/tflite/tflite_supported_operators.cpp @@ -199,7 +199,7 @@ bool TfLiteSupportedOperators::ConstraintFCWeightShape(const Operation *op) auto weights = op->Input(TensorUsage::Weights); assert(weights); assert(weights->tensor); - const auto &shape = weights->tensor->StorageShape(); + const auto &shape = weights->shape; // Total elements must be equal to first-dim * last-dim if ( shape.Size() < 2 || (shape.Elements() != (shape[0] * shape[-1])) ) {