diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5c9977aa9fbb8c36a933f02ee63f56a56fcbe1da..a65f8bf3c397124283db3e58df6aaebf18d726ca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,7 +39,7 @@ repos: hooks: - id: pytest-pre-commit name: Run pre-commit marked tests - entry: pytest -m pre_commit + entry: pytest -m pre_commit --ignore=tests/test_python_api.py language: system pass_filenames: false always_run: true diff --git a/setup.cfg b/setup.cfg index 5d3e8589257f9a6e72d6b6e1162befc9a43a1c96..f861493b9d24a55f215f8a9ac1a5a1d141beaf4b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,4 +21,5 @@ python_files = test_*.py norecursedirs = external markers = cli: marks tests of the command-line interface + python_api: marks tests of the Python API pre_commit: marks tests to be run in pre-commit diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f6559243f2806ee29e0704233936509585a7bd30 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,79 @@ + + +# TOSA Converter for TFLite Test Suite + +This document summarises the test coverage for the tosa-converter-for-tflite project. + +## Requirements to run Tests + +Install the dependencies with: + +```bash +pip install -r requirements.txt +``` + +--- + +## CLI Tests + +The CLI test suite verifies that TFLite models provided as files or through stdin are correctly converted to TOSA MLIR, respecting both output formats (`--bytecode` and `--text`). + +### CLI Test Definition + +CLI tests are defined in `tests/test_cli.hjson` and consumed by `tests/test_cli.py`. + +### Run CLI Tests + +```bash +pytest -v tests/test_cli.py --log-cli-level=DEBUG +``` + +### Adding a CLI Test + +To add a new CLI test, append an object to `tests/test_cli.hjson`. For example, to verify that `--text` without `-o` writes MLIR text to `stdout`: + +```hjson +{ + name: text_to_stdout + description: "Emit MLIR textual form to stdout when --text is given" + args: [ + "--text" + "{model}" + ] + stdin: null + expect_success: true + out_fname: null +} +``` + +--- + +## Python API Tests + +The Python API test suite exercises the `tflite_flatbuffer_to_tosa_mlir` wrapper and tests: + +- **File based** and **stream based** conversions +- **Bytecode** and **text** output format generation options + +### Python API Test Definition + +- **Test script** is `tests/test_python_api.py` + +### Run Python API Tests + +```bash +pytest tests/test_python_api.py -v --log-cli-level=DEBUG +``` + +### Adding a Python API Test + +To add a new text extend `tests/test_python_api.py` with additional tests: + + ```python + def test_some_new_behavior(...): + ... + ``` diff --git a/tests/test_python_api.py b/tests/test_python_api.py new file mode 100644 index 0000000000000000000000000000000000000000..d0300ba5aafa0cbecd61b08a79bb1830c6e1cd7b --- /dev/null +++ b/tests/test_python_api.py @@ -0,0 +1,144 @@ +# +# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# +""" +Test suite for the `tosa-converter-for-tflite` Python API. + +To run tests with logging: + pytest tests/test_python_api.py -v --log-cli-level=DEBUG + +Tests cover: + - File based & stream based conversion + - Bytecode vs text output via `TosaConverterOutputFormat` + - Error cases: invalid sink types, truncated inputs, invalid format args + - Pre-commit sanity checks +""" +import io +import logging +from pathlib import Path +from typing import Any + +import pytest +from tosa_converter_for_tflite import tflite_flatbuffer_to_tosa_mlir +from tosa_converter_for_tflite import TosaConverterOutputFormat + +# ----------------------------------------------------------------------------- +# Logger configuration +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Session scoped fixtures +# ----------------------------------------------------------------------------- +@pytest.fixture(scope="session") +def mobilenet_tflite(tmp_path_factory) -> Path: + """Build a MobileNetV3Small TFLite model for testing.""" + # Skip if TensorFlow isn't available + tf = pytest.importorskip( + "tensorflow", reason="TensorFlow is required for TFLite conversion" + ) + # 1) Load the pretrained Keras model + keras_model = tf.keras.applications.MobileNetV3Small( + weights="imagenet", input_shape=(224, 224, 3) + ) + # 2) Convert it to TFLite + converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) + tflite_bytes = converter.convert() + + # 3) Write it out to a temp .tflite file + model_dir = tmp_path_factory.mktemp("tflite_model") + model_path = model_dir / "mobilenet_v3_small_1.0_224.tflite" + model_path.write_bytes(tflite_bytes) + logger.info("TFLite model written to %s", model_path) + return model_path + + +# ----------------------------------------------------------------------------- +# Tests +# ----------------------------------------------------------------------------- +@pytest.mark.python_api +def test_precommit_sanity() -> None: + """Sanity check: API should be callable.""" + logger.info("Running python API callable sanity checks") + assert callable(tflite_flatbuffer_to_tosa_mlir), "API function missing" + + +@pytest.mark.python_api +def test_bytecode_filesystem(tmp_path: Path, mobilenet_tflite: Path) -> None: + """Ensure that file based bytecode conversion creates a non empty .tosa.mlirbc file.""" + out_file = tmp_path / f"{__name__}.tosa.mlirbc" + logger.info(f"Testing filesystem bytecode output for '{__name__}'") + tflite_flatbuffer_to_tosa_mlir(str(mobilenet_tflite), str(out_file)) + assert out_file.exists(), "Output file not created" + assert out_file.stat().st_size > 0, "Output file is empty" + logger.debug(f"Created bytecode file {out_file} ({out_file.stat().st_size} bytes)") + + +@pytest.mark.python_api +def test_bytecode_stream_vs_filesystem(tmp_path: Path, mobilenet_tflite: Path) -> None: + """Verify that in memory bytecode == file based bytecode and flush() was called.""" + buf = io.BytesIO() + called = False + + def _flush() -> None: + nonlocal called + called = True + + buf.flush = _flush + + tflite_flatbuffer_to_tosa_mlir( + str(mobilenet_tflite), buf, TosaConverterOutputFormat.Bytecode + ) + data_stream = buf.getvalue() + assert data_stream, "In memory bytecode is empty" + assert called, "flush() was not called on the stream" + + fs_path = tmp_path / "sample_fs.tosa.mlirbc" + tflite_flatbuffer_to_tosa_mlir( + str(mobilenet_tflite), str(fs_path), TosaConverterOutputFormat.Bytecode + ) + data_file = fs_path.read_bytes() + + assert data_stream == data_file, "Bytecode mismatch between stream and file" + + +@pytest.mark.python_api +def test_missing_input_file(tmp_path: Path): + """Missing Input file.""" + bad = tmp_path / "does_not_exist.tflite" + with pytest.raises(RuntimeError): + tflite_flatbuffer_to_tosa_mlir(str(bad), "missing_input.tosa") + + +@pytest.mark.python_api +@pytest.mark.parametrize("sink", [123, None, 5.6]) +def test_invalid_sink(mobilenet_tflite: Path, sink: Any) -> None: + """Passing a non path, non stream sink should raise TypeError.""" + with pytest.raises(TypeError): + tflite_flatbuffer_to_tosa_mlir(str(mobilenet_tflite), sink) + + +@pytest.mark.python_api +def test_truncated_tflite(tmp_path: Path, mobilenet_tflite: Path) -> None: + """A truncated .tflite input should cause a RuntimeError.""" + data = mobilenet_tflite.read_bytes() + trunc = data[: len(data) // 2] + trunc_path = tmp_path / "trunc.tflite" + trunc_path.write_bytes(trunc) + with pytest.raises(RuntimeError): + tflite_flatbuffer_to_tosa_mlir( + str(trunc_path), str(tmp_path / "truncated_input.tosa.mlirbc") + ) + + +@pytest.mark.python_api +@pytest.mark.parametrize("fmt", [123, "bogus", None]) +def test_invalid_format(tmp_path: Path, mobilenet_tflite: Path, fmt: Any) -> None: + """Passing an invalid format argument should raise TypeError or ValueError.""" + with pytest.raises((TypeError, ValueError)): + tflite_flatbuffer_to_tosa_mlir( + str(mobilenet_tflite), str(tmp_path / "invalid_format.tosa.mlirbc"), fmt + )