From 2ddd8e1b97fa4ffb7832ad9430e7dd5b35bc58ed Mon Sep 17 00:00:00 2001 From: Deeptanshu Sekhri Date: Fri, 4 Jul 2025 11:14:28 +0100 Subject: [PATCH] feat(cli): introduce tosa-converter-for-tflite CLI, tests, and docs - Add `tosa_converter_for_tflite/cli.py` implementing the new CLI entrypoint - Register CLI script in `setup.py` under `tosa-converter-for-tflite` - Add new test definitions in `tests/test_cli.hjson` and runner in `tests/test_cli.py` - Add tests README in `tests/README.md` - Configure pytest in `setup.cfg` (testpaths, markers, python_files, norecursedirs) - Update `.pre-commit-config.yaml` with `pytest-pre-commit` hook for `pre_commit` marked tests - Extend project `README.md` with CLI usage and streaming examples Signed-off-by: Deeptanshu Sekhri --- .pre-commit-config.yaml | 9 + README.md | 32 +++ setup.cfg | 15 +- setup.py | 6 + tests/test_cli.hjson | 292 +++++++++++++++++++++++ tests/test_cli.py | 384 +++++++++++++++++++++++++++++++ tosa_converter_for_tflite/cli.py | 185 +++++++++++++++ 7 files changed, 922 insertions(+), 1 deletion(-) create mode 100644 tests/test_cli.hjson create mode 100644 tests/test_cli.py create mode 100644 tosa_converter_for_tflite/cli.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ee7e06b..5c9977a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,3 +34,12 @@ repos: entry: clang-format types: ["c++"] args: ["-i"] + +- repo: local + hooks: + - id: pytest-pre-commit + name: Run pre-commit marked tests + entry: pytest -m pre_commit + language: system + pass_filenames: false + always_run: true diff --git a/README.md b/README.md index d33cf1a..7f5b570 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,38 @@ from tosa_converter_for_tflite import tflite_flatbuffer_to_tosa_mlir_bytecode tflite_flatbuffer_to_tosa_mlir_bytecode("model_input.tflite", "model_output.tosa.mlirbc") ``` +## Command-line Interface (CLI) + +The `tosa-converter-for-tflite` CLI provides a simple way to convert `.tflite` models +into TOSA MLIR bytecode. It supports both file-based and stream-based workflows +for easy integration into toolchains or pipelines. + +### Usage Examples + +**Convert a `.tflite` file to a TOSA bytecode file:** + +```bash +tosa-converter-for-tflite model.tflite -o out.tosa.mlirbc +``` + +**Read the model from `stdin` and write to a file:** + +```bash +cat model.tflite | tosa-converter-for-tflite -o out.tosa.mlirbc +``` + +**Read from a file and write to `stdout`:** + +```bash +tosa-converter-for-tflite model.tflite > out.tosa.mlirbc +``` + +**Fully stream the model through the CLI and into another tool:** + +```bash +cat model.tflite | tosa-converter-for-tflite | +``` + ## Supported Platforms This tool supports Linux and Windows(r) on x86_64 architectures. diff --git a/setup.cfg b/setup.cfg index 9bbd89d..5d3e858 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,4 +8,17 @@ max-line-length = 88 extend-ignore = E203, W503, E501, D213, E266 select = B,E,F,W,T4 -exclude = .eggs, build, third_party \ No newline at end of file +exclude = .eggs, build, third_party + +[tool:pytest] +# Only look for tests under tese directories +testpaths = tests + +# Only consider files named test_*.py +python_files = test_*.py + +# Also skip any 'external' directories +norecursedirs = external +markers = + cli: marks tests of the command-line interface + pre_commit: marks tests to be run in pre-commit diff --git a/setup.py b/setup.py index 2eafc7f..4cdd98c 100644 --- a/setup.py +++ b/setup.py @@ -239,4 +239,10 @@ setuptools.setup( bazel_shared_lib_output="bazel-bin/tosa_converter_for_tflite/_tosa_converter_for_tflite_wrapper.so", ), ], + packages=setuptools.find_packages(), + entry_points={ + "console_scripts": [ + "tosa-converter-for-tflite=tosa_converter_for_tflite.cli:main" + ], + }, ) diff --git a/tests/test_cli.hjson b/tests/test_cli.hjson new file mode 100644 index 0000000..a9850d6 --- /dev/null +++ b/tests/test_cli.hjson @@ -0,0 +1,292 @@ +# +# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +# Test cases for tosa-converter-for-tflite CLI +# ----------------------------------------------------------------------------- +# Test‐case schema reference: +# name (string) : unique identifier for the test case +# description (string) : human-readable description +# args (array) : list of CLI args, use {model} and {out} +# stdin (string|null) : null, "model", or one of ["empty","truncated128","random512"] +# expect_success (bool) : whether the CLI invocation should succeed +# out_fname (string|null) : expected output filename relative to {out} +# Optional keys: +# error_pattern (string) : regex to match stderr on failure +# timeout (int) : timeout in seconds for this test +# ----------------------------------------------------------------------------- + +# Test cases for tosa-converter-for-tflite CLI + +[ + # ───────────────────────────────────────────────────────────────────────────── + # Expected successful conversions + # ───────────────────────────────────────────────────────────────────────────── + + { + name: file_input_to_file_output + description: "Convert from a file input to the specified file output" + args: [ + "{model}" + "--tosa-output-path" + "{out}/out_file1.tosa.mlirbc" + ] + stdin: null + expect_success: true + out_fname: "out_file1.tosa.mlirbc" + }, + + { + name: file_input_pos_last + description: "Convert when the positional input appears last in the CLI invocation" + args: [ + "--tosa-output-path" + "{out}/out_file2.tosa.mlirbc" + "{model}" + ] + stdin: null + expect_success: true + out_fname: "out_file2.tosa.mlirbc" + }, + + { + name: stdin_input_to_file_output + description: "Read the model from stdin (dash) and write to file output" + args: [ + "-" + "--tosa-output-path" + "{out}/out_stdin1.tosa.mlirbc" + ] + stdin: model + expect_success: true + out_fname: "out_stdin1.tosa.mlirbc" + }, + + { + name: stdin_input_dash_omitted + description: "Read model from stdin without explicit dash flag and write to file" + args: [ + "--tosa-output-path" + "{out}/out_stdin2.tosa.mlirbc" + ] + stdin: model + expect_success: true + out_fname: "out_stdin2.tosa.mlirbc" + }, + + { + name: verbose_logging + description: "Enable verbose logging during conversion" + args: [ + "{model}" + "--tosa-output-path" + "{out}/out_verbose.tosa.mlirbc" + "--verbose" + ] + stdin: null + expect_success: true + out_fname: "out_verbose.tosa.mlirbc" + }, + + { + name: verbose_and_time + description: "Enable verbose logging and measure execution time" + args: [ + "{model}" + "--tosa-output-path" + "{out}/out_time.tosa.mlirbc" + "--verbose" + "--measure-time" + ] + stdin: null + expect_success: true + out_fname: "out_time.tosa.mlirbc" + }, + + { + name: version_default + description: "Omit TOSA version to use the default value" + args: [ + "{model}" + "--tosa-output-path" + "{out}/version_default.tosa.mlirbc" + ] + stdin: null + expect_success: true + out_fname: "version_default.tosa.mlirbc" + }, + + { + name: shorthand_o + description: "Use the '-o' shorthand instead of '--tosa-output-path'" + args: [ + "{model}" + "-o" + "{out}/shorthand_o.tosa.mlirbc" + ] + stdin: null + expect_success: true + out_fname: "shorthand_o.tosa.mlirbc" + }, + + { + name: shorthand_stdin + description: "Use '-o' shorthand when reading model from stdin" + args: [ + "-" + "-o" + "{out}/shorthand_stdin.tosa.mlirbc" + ] + stdin: model + expect_success: true + out_fname: "shorthand_stdin.tosa.mlirbc" + }, + + { + name: shorthand_all + description: "Combine all shorthands with verbose and time flags" + args: [ + "{model}" + "-o" + "{out}/shorthand_all.tosa.mlirbc" + "--verbose" + "--measure-time" + ] + stdin: null + expect_success: true + out_fname: "shorthand_all.tosa.mlirbc" + }, + + { + name: file_input_to_stdout, + description: "Convert a file input and write TOSA bytecode to stdout", + args: [ + "{model}" + ], + stdin: null, + expect_success: true, + out_fname: null + }, + + # ───────────────────────────────────────────────────────────────────────────── + # Expected failure scenarios + # ───────────────────────────────────────────────────────────────────────────── + + { + name: missing_input_path + description: "Fail when no input path is provided" + args: [ + "--tosa-output-path" + "{out}/missing_input.tosa.mlirbc" + ] + stdin: null + expect_success: false + out_fname: null + }, + + { + name: missing_output_arg + description: "no output path: writes to stdout" + args: [ + "{model}" + ] + stdin: null + expect_success: true + out_fname: null + }, + + { + name: invalid_model_path + description: "Fail on an invalid or non existent model file path" + args: [ + "{out}/nonexistent.tflite" + "--tosa-output-path" + "{out}/invalid_path.tosa.mlirbc" + ] + stdin: null + expect_success: false + out_fname: null + }, + + { + name: empty_stdin + description: "Fail when stdin is empty" + args: [ + "-" + "--tosa-output-path" + "{out}/empty_stdin.tosa.mlirbc" + ] + stdin: empty + expect_success: false + out_fname: null + }, + + { + name: corrupt_input + description: "Fail on truncated or corrupt TFLite data via stdin" + args: [ + "-" + "--tosa-output-path" + "{out}/corrupt.tosa.mlirbc" + ] + stdin: truncated128 + expect_success: false + out_fname: null + }, + + { + name: unknown_argument + description: "Fail when an unknown CLI argument is passed" + args: [ + "{model}" + "--unknown-flag" + "foo" + "--tosa-output-path" + "{out}/unknown_arg.tosa.mlirbc" + ] + stdin: null + expect_success: false + out_fname: null + }, + + { + name: missing_model_ext + description: "Fail when model path lacks the '.tflite' extension" + args: [ + "{out}/somefile" + "--tosa-output-path" + "{out}/wrong_ext.tosa.mlirbc" + ] + stdin: null + expect_success: false + out_fname: null + }, + + { + name: garbage_stdin + description: "Fail when random bytes are provided via stdin" + args: [ + "-" + "--tosa-output-path" + "{out}/garbage.tosa.mlirbc" + ] + stdin: random512 + expect_success: false + out_fname: null + }, + + { + name: shorthand_missing_o + description: "Fail when '-o' shorthand is provided without a value" + args: [ + "{model}" + "-o" + ] + stdin: null + expect_success: false + out_fname: null + }, + +] diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..1352ef7 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,384 @@ +# +# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# +""" +Test suite for the `tosa-converter-for-tflite` CLI. + +To run with debug logs: + pytest tests/test_cli.py -v --log-cli-level=DEBUG + +To add a test: + Edit `test_cli.hjson` and define a new entry with these fields: + + - name (str): unique test ID + - description (str): what the test does + - args (List[str]): CLI args, use `{model}` and `{out}` + - stdin (Optional[str]): one of null, "model", or a special key + - expect_success (bool) + - out_fname (Optional[str]) + + Optional: + - error_pattern (str): regex to match stderr + - timeout (int): seconds + This test suite will pick it up automatically once added to the JSON. + +Troubleshoot: + - Increase `timeout` if conversions take longer. + - Ensure `tosa-converter-for-tflite` is on your PATH. +""" +import logging +import random +import re +import shutil +import stat +import subprocess +import sys +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +import hjson +import pytest + +# ─────────────────────────────────────────────────────────────────────────────── +# Logging configuration +# ─────────────────────────────────────────────────────────────────────────────── +# get the PyTest configured logger +logger = logging.getLogger(__name__) + +# ─────────────────────────────────────────────────────────────────────────────── +# Configuration constants +# ─────────────────────────────────────────────────────────────────────────────── +CLI_BINARY = "tosa-converter-for-tflite" +DEFAULT_TIMEOUT = 60 # Timeout for subprocess.run in seconds +JSON_PATH = Path(__file__).parent / "test_cli.hjson" # external test definitions + +# ─────────────────────────────────────────────────────────────────────────────── +# Load & validate test cases from JSON +# ─────────────────────────────────────────────────────────────────────────────── +with JSON_PATH.open() as f: + CLI_TEST_CASES: List[Dict[str, Any]] = hjson.load(f) + +logger.info("Loaded %d test cases from %s", len(CLI_TEST_CASES), JSON_PATH) + + +def validate_cases(cases: List[Dict[str, Any]]) -> None: + """Ensure each JSON entry has the correct schema.""" + logger.debug("Validating %d test cases", len(cases)) + required = {"name", "description", "args", "stdin", "expect_success", "out_fname"} + optional = {"error_pattern", "timeout"} + for idx, case in enumerate(cases): + keys = set(case.keys()) + missing = required - keys + if missing: + logger.error("Case #%d missing keys: %s", idx, missing) + raise ValueError(f"[cli_test_cases.json][#{idx}] missing keys: {missing}") + extras = keys - required - optional + if extras: + logger.error("Case %r has unknown keys: %s", case.get("name"), extras) + raise ValueError(f"[cli_test_cases.json][#{idx}] unknown keys: {extras}") + if not isinstance(case["name"], str): + raise TypeError(f"[#{idx}] name must be str") + if not isinstance(case["description"], str): + raise TypeError(f"[{case['name']!r}] description must be str") + if not ( + isinstance(case["args"], list) + and all(isinstance(a, str) for a in case["args"]) + ): + raise TypeError(f"[{case['name']!r}] args must be List[str]") + if case["stdin"] is not None and not isinstance(case["stdin"], str): + raise TypeError(f"[{case['name']!r}] stdin must be null or str") + if not isinstance(case["expect_success"], bool): + raise TypeError(f"[{case['name']!r}] expect_success must be bool") + if case["out_fname"] is not None and not isinstance(case["out_fname"], str): + raise TypeError(f"[{case['name']!r}] out_fname must be null or str") + if "error_pattern" in case and not isinstance(case["error_pattern"], str): + raise TypeError(f"[{case['name']!r}] error_pattern must be str") + if "timeout" in case: + t = case["timeout"] + if not isinstance(t, int) or t <= 0: + raise TypeError(f"[{case['name']!r}] timeout must be positive int") + + +validate_cases(CLI_TEST_CASES) + +# ─────────────────────────────────────────────────────────────────────────────── +# Helper functions and handlers +# ─────────────────────────────────────────────────────────────────────────────── +# Handlers for special stdin test cases +_STDIN_HANDLERS = { + "empty": lambda _: b"", + "truncated128": lambda p: Path(p).read_bytes()[:128], + "random512": lambda _: random.randbytes(512), +} + + +def build_stdin(kind: Optional[str], model_path: Path) -> Optional[bytes]: + """ + Build the required stdin bytes for a given test kind. + Returns None if no stdin should be provided. + """ + logger.debug("Building stdin payload: kind=%r, model_path=%s", kind, model_path) + if kind == "model": + # Read the entire model file + return model_path.read_bytes() + if kind in _STDIN_HANDLERS: + # Delegate to handler + return _STDIN_HANDLERS[kind](str(model_path)) + return None + + +def set_readonly(path: Path) -> None: + """Mark a path as read-only for the duration of a test.""" + logger.debug("Setting %s to read-only", path) + if sys.platform.startswith("win"): + path.chmod(stat.S_IREAD) + else: + path.chmod(0o500) + + +def restore_permissions(path: Path) -> None: + """Restore write and execute permissions on a path after testing.""" + logger.debug("Restoring permissions for %s", path) + if sys.platform.startswith("win"): + path.chmod(stat.S_IWRITE | stat.S_IREAD | stat.S_IEXEC) + else: + path.chmod(0o700) + + +# ─────────────────────────────────────────────────────────────────────────────── +# Test-case abstraction Class +# ─────────────────────────────────────────────────────────────────────────────── +class CLITestCase: + """ + Encapsulates a single CLI test case by providing: + - command construction + - stdin selection + - success/failure assertions + - output validation or error-pattern matching + """ + + def __init__(self, **kwargs: Any): + self.name: str = kwargs["name"] + self.description: str = kwargs["description"] + self.args: List[str] = kwargs["args"] + self.stdin: Optional[str] = kwargs["stdin"] + self.expect_success: bool = kwargs["expect_success"] + self.out_fname: Optional[str] = kwargs["out_fname"] + self.error_pattern: str = kwargs.get("error_pattern", r"error") + self.timeout: int = kwargs.get("timeout", DEFAULT_TIMEOUT) + + def build_cmd(self, cli_path: str, model_path: Path, out_dir: Path) -> List[str]: + """Format CLI arguments with actual paths and prepend the CLI.""" + formatted_args = [a.format(model=model_path, out=out_dir) for a in self.args] + return [cli_path] + formatted_args + + def run_and_assert(self, cli_path: str, model_path: Path, out_dir: Path) -> None: + """ + Execute the CLI command and assert based on expected outcome. + On success: exit code should be 0 and nonempty output file. + On failure: nonzero exit and stderr matches error_pattern. + """ + cmd = self.build_cmd(cli_path, model_path, out_dir) + stdin_bytes = build_stdin(self.stdin, model_path) + logger.info("Running test %r: %s", self.name, cmd) + result = subprocess.run( + cmd, + input=stdin_bytes, + capture_output=True, + text=False, + timeout=self.timeout, + check=False, + ) + stdout_bytes = result.stdout or b"" + stderr_bytes = result.stderr or b"" + logger.debug( + "Test %r result: returncode=%d, stdout=%d bytes, stderr=%d bytes", + self.name, + result.returncode, + len(stdout_bytes), + len(stderr_bytes), + ) + stderr = stderr_bytes.decode("utf-8", "replace") + + if self.expect_success: + # Success: ensure zero exit status and valid output + if result.returncode != 0: + logger.error("Test %r failed with rc=%d", self.name, result.returncode) + pytest.fail(f"{self.name} failed (rc={result.returncode}):\n{stderr}") + if self.out_fname is not None: + out_file = out_dir / self.out_fname + assert ( + out_file.exists() and out_file.stat().st_size > 0 + ), f"{self.name} did not produce output at {out_file}" + else: + # Failure: nonzero exit status and stderr should match error pattern + assert result.returncode != 0, ( + f"{self.name} unexpectedly succeeded; stdout:\n" + f"{stdout_bytes.decode('utf-8', 'replace')}" + ) + assert re.search(self.error_pattern, stderr.lower()), ( + f"{self.name} stderr did not match pattern {self.error_pattern!r}.\n" + f"Stderr:\n{stderr}" + ) + + +# Instantiate CLITestCase objects +CLI_CASES = [CLITestCase(**case) for case in CLI_TEST_CASES] + + +# ─────────────────────────────────────────────────────────────────────────────── +# Fixtures +# ─────────────────────────────────────────────────────────────────────────────── +@pytest.fixture(scope="session") +def cli_path() -> str: + """Locate the CLI or skip the tests if not found.""" + path = shutil.which(CLI_BINARY) + logger.info("Located CLI binary at %s", path) + if not path: + pytest.skip(f"{CLI_BINARY!r} not found in PATH") + return path + + +@pytest.fixture(scope="session") +def mobilenet_tflite(tmp_path_factory) -> Path: + """Build a MobileNetV3Small TFLite model for testing.""" + # Skip if TensorFlow isn't available + tf = pytest.importorskip( + "tensorflow", reason="TensorFlow is required for TFLite conversion" + ) + # 1) Load the pretrained Keras model + keras_model = tf.keras.applications.MobileNetV3Small( + weights="imagenet", input_shape=(224, 224, 3) + ) + # 2) Convert it to TFLite + converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) + tflite_bytes = converter.convert() + + # 3) Write it out to a temp .tflite file + model_dir = tmp_path_factory.mktemp("tflite_model") + model_path = model_dir / "mobilenet_v3_small_1.0_224.tflite" + model_path.write_bytes(tflite_bytes) + logger.info("TFLite model written to %s", model_path) + return model_path + + +@pytest.fixture(scope="session") +def out_dir(tmp_path_factory) -> Path: + """ + Create a single output directory for the entire test session, + and attach a FileHandler that logs to out_dir/'test_cli.log'. + and at end, remove the directory if it's empty. + """ + directory = tmp_path_factory.mktemp("cli_out") + logger.info("Session-wide output directory: %s", directory) + + # Set up file logging once + log_path = directory / "test_cli.log" + file_handler = logging.FileHandler(log_path, mode="w") + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter( + logging.Formatter("%(asctime)s %(levelname)s [%(name)s] %(message)s") + ) + logger.addHandler(file_handler) + + yield directory + + # Clean up logger handler after all tests have run + logger.removeHandler(file_handler) + + +# ─────────────────────────────────────────────────────────────────────────────── +# Parametrized CLI tests +# ─────────────────────────────────────────────────────────────────────────────── +@pytest.mark.cli +@pytest.mark.parametrize("case", CLI_CASES, ids=[c.name for c in CLI_CASES]) +def test_cli(case: CLITestCase, cli_path: str, mobilenet_tflite: Path, out_dir: Path): + """Run all CLI test cases.""" + case.run_and_assert(cli_path, mobilenet_tflite, out_dir) + + +# ─────────────────────────────────────────────────────────────────────────────── +# Special-case tests (help, filesystem edge cases) +# ─────────────────────────────────────────────────────────────────────────────── +@pytest.mark.cli +@pytest.mark.pre_commit +def test_help_flag_writes_usage(cli_path: str, out_dir: Path) -> None: + """--help should print usage and exit code 0.""" + result = subprocess.run( + [cli_path, "--help"], capture_output=True, text=True, timeout=DEFAULT_TIMEOUT + ) + help_txt = out_dir / "help.txt" + help_txt.write_text(result.stdout) + assert result.returncode == 0 + assert "help" in help_txt.read_text().lower() + + +@pytest.mark.cli +def test_output_to_unwritable_directory( + cli_path: str, mobilenet_tflite: Path, out_dir: Path +) -> None: + """Attempt to write into a read-only directory. Test should fail without producing a file.""" + bad_dir = out_dir / "nowrite" + bad_dir.mkdir() + set_readonly(bad_dir) + bad_file = bad_dir / "no_perm.tosa.mlirbc" + try: + proc = subprocess.run( + [cli_path, str(mobilenet_tflite), "--tosa-output-path", str(bad_file)], + capture_output=True, + text=True, + timeout=DEFAULT_TIMEOUT, + ) + stderr = proc.stderr or "" + assert proc.returncode != 0 + assert re.search(r"cannot open output file", stderr.lower()) + assert not bad_file.exists() + finally: + restore_permissions(bad_dir) + + +@pytest.mark.cli +def test_output_to_non_directory( + cli_path: str, mobilenet_tflite: Path, out_dir: Path +) -> None: + """Writing under a path whose parent is a file should error.""" + not_a_dir = out_dir / "not_a_dir" + not_a_dir.write_text("I am a file") + target = not_a_dir / "foo.tosa.mlirbc" + result = subprocess.run( + [cli_path, str(mobilenet_tflite), "--tosa-output-path", str(target)], + capture_output=True, + text=True, + timeout=DEFAULT_TIMEOUT, + ) + stderr = result.stderr or "" + assert result.returncode != 0 + assert re.search(r"not a directory|cannot", stderr.lower()) + assert not target.exists() + + +@pytest.mark.cli +def test_readonly_existing_file( + cli_path: str, mobilenet_tflite: Path, out_dir: Path +) -> None: + """Overwriting a read-only file should fail.""" + existing = out_dir / "existing.tosa.mlirbc" + existing.write_bytes(b"\x00") + set_readonly(existing) + try: + r = subprocess.run( + [cli_path, str(mobilenet_tflite), "--tosa-output-path", str(existing)], + capture_output=True, + text=True, + timeout=DEFAULT_TIMEOUT, + ) + stderr = r.stderr or "" + assert r.returncode != 0 + assert re.search(r"cannot open output file", stderr.lower()) + finally: + restore_permissions(existing) diff --git a/tosa_converter_for_tflite/cli.py b/tosa_converter_for_tflite/cli.py new file mode 100644 index 0000000..bb4e921 --- /dev/null +++ b/tosa_converter_for_tflite/cli.py @@ -0,0 +1,185 @@ +# +# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# +""" +TOSA Converter for TFLite CLI + +This script provides a command-line interface to convert a TFLite +model (.tflite) into TOSA bytecode. + +It supports reading input from a file or from standard input (`stdin`), +and writing output to a specified path. + +Example usage: + $ tosa-converter-for-tflite --tosa-output-path out.tosa.mlirbc model.tflite + $ cat model.tflite | tosa-converter-for-tflite --tosa-output-path out.tosa.mlirbc +""" +import argparse +import logging +import sys +import time +from pathlib import Path + +from _tosa_converter_for_tflite_wrapper import tflite_flatbuffer_to_tosa_mlir_bytecode + +# ----------------------------------------------------------------------------- +# Logging Configuration +# ----------------------------------------------------------------------------- +logger = logging.getLogger("tosa_cli") +handler = logging.StreamHandler() +formatter = logging.Formatter("[%(levelname)s] %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) +logger.setLevel(logging.INFO) + + +def parse_arguments() -> argparse.Namespace: + """ + Define and parse CLI arguments. + + Returns: + argparse.Namespace: Parsed arguments. + """ + parser = argparse.ArgumentParser( + description=( + "Convert a TensorFlow Lite (.tflite) model into TOSA bytecode format.\n\n" + "Examples:\n" + " tosa-converter-for-tflite model.tflite --tosa-output-path out.tosa.mlirbc\n" + " cat model.tflite | tosa-converter-for-tflite --tosa-output-path out.tosa.mlirbc" + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "tflite_model_path", + nargs="?", + default="-", + help="Path to the input .tflite model file. Use '-' or omit to read from stdin.", + ) + + parser.add_argument( + "-o", + "--tosa-output-path", + help="Path to the output file where the TOSA bytecode will be saved.", + ) + + parser.add_argument( + "--verbose", + action="store_true", + help="Enable debug-level logging output.", + ) + + parser.add_argument( + "--measure-time", + action="store_true", + help="Measure and report conversion time.", + ) + + return parser.parse_args() + + +def validate_input_path(path: str) -> None: + """ + Validate input model path, unless reading from stdin. + + Args: + path (str): Input file path or "-" + + Raises: + ValueError: if file doesn't exist + """ + if path == "-": + return + + p = Path(path) + + if not p.is_file(): + raise ValueError(f"Input file does not exist: {path}") + + +def convert_model(input_path: str, output_path: str) -> None: + """ + Convert a TFLite model to TOSA bytecode and write to a file. + + Args: + input_path (str): Path to input .tflite model, or '-' for stdin. + output_path (str): Path to write the converted TOSA bytecode. + """ + logger.debug("Invoking converter:") + logger.debug(f" Input Path: {input_path}") + logger.debug(f" Output Path: {output_path or 'stdout'}") + + # NOTE: If input_path is '-', the underlying C++ implementation + # will automatically read from stdin. This behavior is implemented + # by MemoryBuffer::getFileOrSTDIN + # Also, if output path is none, the output is written to stdout + if output_path is None: + tflite_flatbuffer_to_tosa_mlir_bytecode(input_path, sys.stdout.buffer) + logger.info("TOSA bytecode written to stdout") + else: + tflite_flatbuffer_to_tosa_mlir_bytecode(input_path, output_path) + logger.info(f"TOSA bytecode written to: {output_path}") + + +def setup_logging(verbose: bool) -> None: + """ + Configure global logging level and format. + """ + level = logging.DEBUG if verbose else logging.INFO + logger.setLevel(level) + + +def run_cli(args: argparse.Namespace) -> int: + """ + Main logic for CLI execution. + + Args: + args (argparse.Namespace): Parsed CLI arguments. + + Returns: + int: Exit code (0 for success, 1 for failure, 2 for usage error). + """ + setup_logging(args.verbose) + + if args.tflite_model_path == "-" and sys.stdin.isatty(): + logger.error( + "Standard input is empty. Pipe a tflite model or provide a file path." + ) + return 2 + + try: + validate_input_path(args.tflite_model_path) + + if args.measure_time: + logger.debug("Timing enabled: measuring conversion duration") + start = time.time() + + convert_model(args.tflite_model_path, args.tosa_output_path) + + if args.measure_time: + elapsed = time.time() - start + logger.info(f"Conversion time: {elapsed:.2f} seconds") + + return 0 + + except ValueError as ve: + logger.error(f"Invalid argument: {ve}") + return 2 + except Exception as e: + logger.error(f"Conversion failed: {e}") + return 1 + + +def main() -> None: + """ + CLI entry point. + """ + args = parse_arguments() + exit_code = run_cli(args) + sys.exit(exit_code) + + +if __name__ == "__main__": + main() -- GitLab