diff --git a/README.md b/README.md index 8aafa09412446e082400c36458ca817a387f8488..ae763116ea148713dcf26acbe98f1af33bbca00b 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ # TOSA converter for TFLite -A tool to legalize TFLite FlatBuffer to TOSA MLIR Bytecode. +A tool to legalize TFLite FlatBuffer to TOSA MLIR Bytecode or Text. ## Python wheel creation @@ -23,42 +23,56 @@ The wheel should now be in the `dist` directory. ## Usage +### Python API + ```python -from tosa_converter_for_tflite import tflite_flatbuffer_to_tosa_mlir_bytecode +from tosa_converter_for_tflite import tflite_flatbuffer_to_tosa_mlir, TosaConverterOutputFormat + +# default (bytecode) output +tflite_flatbuffer_to_tosa_mlir("model_input.tflite", "model_output.tosa.mlirbc") -tflite_flatbuffer_to_tosa_mlir_bytecode("model_input.tflite", "model_output.tosa.mlirbc") +# textual MLIR output +tflite_flatbuffer_to_tosa_mlir( + "model_input.tflite", + "model_output.mlir", + TosaConverterOutputFormat.Text +) ``` -## Command-line Interface (CLI) +### Command-line Interface (CLI) The `tosa-converter-for-tflite` CLI provides a simple way to convert `.tflite` models -into TOSA MLIR bytecode. It supports both file-based and stream-based workflows +into TOSA MLIR (bytecode or text). It supports both file-based and stream-based workflows for easy integration into toolchains or pipelines. 
### Usage Examples -**Convert a `.tflite` file to a TOSA bytecode file:** +**Convert a `.tflite` file to a TOSA bytecode file (default):** ```bash -tosa-converter-for-tflite model.tflite -o out.tosa.mlirbc +tosa-converter-for-tflite model.tflite --bytecode -o out.tosa.mlirbc ``` -**Read the model from `stdin` and write to a file:** +**Read the model from a file and write textual TOSA MLIR to a file:** ```bash -cat model.tflite | tosa-converter-for-tflite -o out.tosa.mlirbc +tosa-converter-for-tflite model.tflite --text -o out.mlir ``` -**Read from a file and write to `stdout`:** +**Read the model from a file and write textual TOSA MLIR to stdout:** ```bash -tosa-converter-for-tflite model.tflite > out.tosa.mlirbc +tosa-converter-for-tflite model.tflite --text ``` -**Fully stream the model through the CLI and into another tool:** +**Stream the model from stdin and write bytecode to a file:** +```bash +cat model.tflite | tosa-converter-for-tflite --bytecode -o out.tosa.mlirbc +``` +**Fully stream the model from stdin and write textual MLIR to stdout:** ```bash -cat model.tflite | tosa-converter-for-tflite | +cat model.tflite | tosa-converter-for-tflite --text ``` ## Supported Platforms diff --git a/tests/test_cli.hjson b/tests/test_cli.hjson index a9850d66dd224facdf931999c6d01e766637989f..fee5614d8d0cc61650318c3cd306e6cea3572e77 100644 --- a/tests/test_cli.hjson +++ b/tests/test_cli.hjson @@ -170,6 +170,46 @@ out_fname: null }, + { + name: text_flag_file_output + description: "Use --text to emit human readable MLIR to a file" + args: [ + "--text" + "{model}" + "--tosa-output-path" + "{out}/out_text.mlir" + ] + stdin: null + expect_success: true + out_fname: "out_text.mlir" + }, + + { + name: text_flag_stdout + description: "Use --text to emit human readable MLIR to stdout" + args: [ + "--text" + "{model}" + ] + stdin: null + expect_success: true + out_fname: null + }, + + { + name: bytecode_flag_file_output + description: "Use --bytecode to emit MLIR bytecode to a file" + args: 
[ + "--bytecode" + "{model}" + "--tosa-output-path" + "{out}/out_bytecode.tosa.mlirbc" + ] + stdin: null + expect_success: true + out_fname: "out_bytecode.tosa.mlirbc" + }, + # ───────────────────────────────────────────────────────────────────────────── # Expected failure scenarios # ───────────────────────────────────────────────────────────────────────────── @@ -289,4 +329,19 @@ out_fname: null }, + { + name: mutually_exclusive_format_flags + description: "Specifying both --text and --bytecode must error out" + args: [ + "--text" + "--bytecode" + "{model}" + "--tosa-output-path" + "{out}/dummy" + ] + stdin: null + expect_success: false + out_fname: null + error_pattern: "not allowed with argument" + }, ] diff --git a/tosa_converter_for_tflite/__init__.py b/tosa_converter_for_tflite/__init__.py index b0a305a8d038d54e8d69b5526fe5c65db6f87d66..36440f1ec299517daa99bed966f357618458f8ec 100644 --- a/tosa_converter_for_tflite/__init__.py +++ b/tosa_converter_for_tflite/__init__.py @@ -3,12 +3,14 @@ # # SPDX-License-Identifier: Apache-2.0 # -from _tosa_converter_for_tflite_wrapper import _tflite_flatbuffer_to_tosa_mlir_bytecode -from _tosa_converter_for_tflite_wrapper import tflite_flatbuffer_to_tosa_mlir_bytecode +from _tosa_converter_for_tflite_wrapper import _tflite_flatbuffer_to_tosa_mlir +from _tosa_converter_for_tflite_wrapper import tflite_flatbuffer_to_tosa_mlir +from _tosa_converter_for_tflite_wrapper import TosaConverterOutputFormat __all__ = [ - "_tflite_flatbuffer_to_tosa_mlir_bytecode", - "tflite_flatbuffer_to_tosa_mlir_bytecode", + "_tflite_flatbuffer_to_tosa_mlir", + "tflite_flatbuffer_to_tosa_mlir", + "TosaConverterOutputFormat", ] try: diff --git a/tosa_converter_for_tflite/cli.py b/tosa_converter_for_tflite/cli.py index bb4e9211c755b60d0a8034cc7f01df1e46be627b..2fdc057fae8f2ded2ec897e63ff10f24240863b5 100644 --- a/tosa_converter_for_tflite/cli.py +++ b/tosa_converter_for_tflite/cli.py @@ -7,14 +7,15 @@ TOSA Converter for TFLite CLI This script provides 
a command-line interface to convert a TFLite -model (.tflite) into TOSA bytecode. +model (.tflite) into TOSA MLIR (bytecode or text). It supports reading input from a file or from standard input (`stdin`), and writing output to a specified path. Example usage: - $ tosa-converter-for-tflite --tosa-output-path out.tosa.mlirbc model.tflite - $ cat model.tflite | tosa-converter-for-tflite --tosa-output-path out.tosa.mlirbc + $ tosa-converter-for-tflite model.tflite --bytecode -o out.tosa.mlirbc + $ tosa-converter-for-tflite model.tflite --text -o out.tosa.mlir + $ cat model.tflite | tosa-converter-for-tflite --text """ import argparse import logging @@ -22,7 +23,8 @@ import sys import time from pathlib import Path -from _tosa_converter_for_tflite_wrapper import tflite_flatbuffer_to_tosa_mlir_bytecode +from _tosa_converter_for_tflite_wrapper import tflite_flatbuffer_to_tosa_mlir +from _tosa_converter_for_tflite_wrapper import TosaConverterOutputFormat # ----------------------------------------------------------------------------- # Logging Configuration @@ -44,10 +46,11 @@ def parse_arguments() -> argparse.Namespace: """ parser = argparse.ArgumentParser( description=( - "Convert a TensorFlow Lite (.tflite) model into TOSA bytecode format.\n\n" + "Convert a TensorFlow Lite (.tflite) model into TOSA MLIR (bytecode or text).\n\n" "Examples:\n" - " tosa-converter-for-tflite model.tflite --tosa-output-path out.tosa.mlirbc\n" - " cat model.tflite | tosa-converter-for-tflite --tosa-output-path out.tosa.mlirbc" + " tosa-converter-for-tflite model.tflite --bytecode -o out.tosa.mlirbc\n" + " tosa-converter-for-tflite model.tflite --text -o out.tosa.mlir\n" + " cat model.tflite | tosa-converter-for-tflite --text" ), formatter_class=argparse.RawDescriptionHelpFormatter, ) @@ -59,10 +62,23 @@ def parse_arguments() -> argparse.Namespace: help="Path to the input .tflite model file. 
Use '-' or omit to read from stdin.", ) + # Mutually exclusive flags for output format + grp = parser.add_mutually_exclusive_group() + grp.add_argument( + "--bytecode", + action="store_true", + help="Emit MLIR bytecode (default if neither flag is given).", + ) + grp.add_argument( + "--text", + action="store_true", + help="Emit human-readable MLIR textual form.", + ) + parser.add_argument( "-o", "--tosa-output-path", - help="Path to the output file where the TOSA bytecode will be saved.", + help="Path to the output file where the TOSA MLIR will be saved.", ) parser.add_argument( @@ -99,28 +115,34 @@ def validate_input_path(path: str) -> None: raise ValueError(f"Input file does not exist: {path}") -def convert_model(input_path: str, output_path: str) -> None: +def convert_model( + input_path: str, + output_path: str | None, + fmt: TosaConverterOutputFormat, +) -> None: """ - Convert a TFLite model to TOSA bytecode and write to a file. + Convert a TFLite model to TOSA MLIR and write to a file or stdout. Args: input_path (str): Path to input .tflite model, or '-' for stdin. - output_path (str): Path to write the converted TOSA bytecode. + output_path (str | None): Path to write the converted TOSA MLIR. + fmt (TosaConverterOutputFormat): Bytecode or Text. """ logger.debug("Invoking converter:") logger.debug(f" Input Path: {input_path}") logger.debug(f" Output Path: {output_path or 'stdout'}") + logger.debug(f" Output Format: {fmt.name}") # NOTE: If input_path is '-', the underlying C++ implementation # will automatically read from stdin. 
This behavior is implemented # by MemoryBuffer::getFileOrSTDIN # Also, if output path is none, the output is written to stdout if output_path is None: - tflite_flatbuffer_to_tosa_mlir_bytecode(input_path, sys.stdout.buffer) - logger.info("TOSA bytecode written to stdout") + tflite_flatbuffer_to_tosa_mlir(input_path, sys.stdout.buffer, fmt) + logger.info(f"TOSA MLIR ({fmt.name.lower()}) written to stdout") else: - tflite_flatbuffer_to_tosa_mlir_bytecode(input_path, output_path) - logger.info(f"TOSA bytecode written to: {output_path}") + tflite_flatbuffer_to_tosa_mlir(input_path, output_path, fmt) + logger.info(f"TOSA MLIR ({fmt.name.lower()}) written to: {output_path}") def setup_logging(verbose: bool) -> None: @@ -152,11 +174,22 @@ def run_cli(args: argparse.Namespace) -> int: try: validate_input_path(args.tflite_model_path) + # Determine format (default to Bytecode) + fmt = ( + TosaConverterOutputFormat.Text + if args.text + else TosaConverterOutputFormat.Bytecode + ) + if args.measure_time: logger.debug("Timing enabled: measuring conversion duration") start = time.time() - convert_model(args.tflite_model_path, args.tosa_output_path) + convert_model( + args.tflite_model_path, + args.tosa_output_path, + fmt, + ) if args.measure_time: elapsed = time.time() - start diff --git a/tosa_converter_for_tflite/tosa_converter_for_tflite.cc b/tosa_converter_for_tflite/tosa_converter_for_tflite.cc index 3a6f85627be8ed1b4a230e3eb6d08f4d6a20332c..7cd2791e6116c5ce1b5a143ad8a6f1619ce3d366 100644 --- a/tosa_converter_for_tflite/tosa_converter_for_tflite.cc +++ b/tosa_converter_for_tflite/tosa_converter_for_tflite.cc @@ -76,8 +76,9 @@ void LegalizeTFLToTOSA(mlir::ModuleOp mlir_module) { namespace tosa_converter_for_tflite { -void TFLiteFlatBufferToTosaMLIRBytecode( +void TFLiteFlatBufferToTosaMLIR( const std::string &tflite_input_file, const std::string &tosa_output_file, + TosaConverterOutputFormat output_format, const std::unordered_map> &override_tflite_input_shape) { std::string 
error_message; @@ -88,14 +89,15 @@ void TFLiteFlatBufferToTosaMLIRBytecode( } // Delegate to streaming overload - TFLiteFlatBufferToTosaMLIRBytecode(tflite_input_file, output->os(), - override_tflite_input_shape); + TFLiteFlatBufferToTosaMLIR(tflite_input_file, output->os(), output_format, + override_tflite_input_shape); output->keep(); } -void TFLiteFlatBufferToTosaMLIRBytecode( +void TFLiteFlatBufferToTosaMLIR( const std::string &tflite_input_file, llvm::raw_ostream &tosa_output_stream, + TosaConverterOutputFormat output_format, const std::unordered_map> &override_tflite_input_shape) { mlir::MLIRContext context; @@ -103,10 +105,16 @@ void TFLiteFlatBufferToTosaMLIRBytecode( OverrideTFLiteInputShape(*mlir_module, override_tflite_input_shape); LegalizeTFLToTOSA(*mlir_module); - const mlir::LogicalResult writeRes = - mlir::writeBytecodeToFile(*mlir_module, tosa_output_stream); - if (mlir::failed(writeRes)) { - throw std::runtime_error("Could not write bytecode output."); + if (output_format == TosaConverterOutputFormat::Bytecode) { + if (mlir::failed( + mlir::writeBytecodeToFile(*mlir_module, tosa_output_stream))) { + throw std::runtime_error("Could not write bytecode output."); + } + } else if (output_format == TosaConverterOutputFormat::Text) { + mlir_module->print(tosa_output_stream); + tosa_output_stream.flush(); + } else { + throw std::invalid_argument("Unsupported TosaConverterOutputFormat"); } } diff --git a/tosa_converter_for_tflite/tosa_converter_for_tflite.h b/tosa_converter_for_tflite/tosa_converter_for_tflite.h index 77dbfbed3b642ac046290cfaf24f4f55c816cf66..a89b3af8d0140fd28cc91f852622b13e1bd18db3 100644 --- a/tosa_converter_for_tflite/tosa_converter_for_tflite.h +++ b/tosa_converter_for_tflite/tosa_converter_for_tflite.h @@ -16,15 +16,22 @@ namespace tosa_converter_for_tflite { -void TFLiteFlatBufferToTosaMLIRBytecode( +// Control whether we emit bytecode or human readable MLIR. 
+enum class TosaConverterOutputFormat { Bytecode, Text }; + +void TFLiteFlatBufferToTosaMLIR( const std::string& tflite_input_file, const std::string& tosa_output_file, + TosaConverterOutputFormat output_format = + TosaConverterOutputFormat::Bytecode, const std::unordered_map>& - override_tflite_input_shape); + override_tflite_input_shape = {}); -void TFLiteFlatBufferToTosaMLIRBytecode( +void TFLiteFlatBufferToTosaMLIR( const std::string& tflite_input_file, llvm::raw_ostream& tosa_output_stream, + TosaConverterOutputFormat output_format = + TosaConverterOutputFormat::Bytecode, const std::unordered_map>& - override_tflite_input_shape); + override_tflite_input_shape = {}); } // namespace tosa_converter_for_tflite diff --git a/tosa_converter_for_tflite/tosa_converter_for_tflite_pybind11.cc b/tosa_converter_for_tflite/tosa_converter_for_tflite_pybind11.cc index afbc7b1bbd5ecec3fcfa4a102d6b6f4aea4a6025..1f40c34b674e579177477ee81bf8820f36893d70 100644 --- a/tosa_converter_for_tflite/tosa_converter_for_tflite_pybind11.cc +++ b/tosa_converter_for_tflite/tosa_converter_for_tflite_pybind11.cc @@ -18,26 +18,26 @@ namespace { -/// Dispatch conversion from TFLite to TOSA MLIR bytecode. +/// Dispatch conversion from TFLite to TOSA MLIR. /// Decides if output is a filesystem path or a Python binary stream. 
void DispatchTFLiteToTosaConversion( - const std::string& tflite_input_file, - const pybind11::object& tosa_bytecode_sink, + const std::string& tflite_input_file, const pybind11::object& tosa_sink, + tosa_converter_for_tflite::TosaConverterOutputFormat output_format, const std::unordered_map>& override_shapes) { // Case 1: output is a filesystem path - if (pybind11::isinstance(tosa_bytecode_sink)) { - std::string path = tosa_bytecode_sink.cast(); - tosa_converter_for_tflite::TFLiteFlatBufferToTosaMLIRBytecode( - tflite_input_file, path, override_shapes); + if (pybind11::isinstance(tosa_sink)) { + std::string path = tosa_sink.cast(); + tosa_converter_for_tflite::TFLiteFlatBufferToTosaMLIR( + tflite_input_file, path, output_format, override_shapes); return; } // Case 2: any object with a write(bytes) method - if (pybind11::hasattr(tosa_bytecode_sink, "write")) { + if (pybind11::hasattr(tosa_sink, "write")) { /* * Using the "buffer-then-write" approach when writing to object: - * 1) LLVM emits all bytecode into a single std::string + * 1) LLVM emits output into a single std::string * 2) Once complete, we copy that buffer to a Python io.BinaryIO * object in one call. * @@ -48,23 +48,23 @@ void DispatchTFLiteToTosaConversion( std::string buffer; llvm::raw_string_ostream llvm_stream(buffer); - tosa_converter_for_tflite::TFLiteFlatBufferToTosaMLIRBytecode( - tflite_input_file, llvm_stream, override_shapes); + tosa_converter_for_tflite::TFLiteFlatBufferToTosaMLIR( + tflite_input_file, llvm_stream, output_format, override_shapes); llvm_stream.flush(); auto view = pybind11::memoryview::from_buffer( buffer.data(), {static_cast(buffer.size())}, {static_cast(1)}); - tosa_bytecode_sink.attr("write")(view); + tosa_sink.attr("write")(view); // If they also provide a flush(), call it. 
- if (pybind11::hasattr(tosa_bytecode_sink, "flush")) { - tosa_bytecode_sink.attr("flush")(); + if (pybind11::hasattr(tosa_sink, "flush")) { + tosa_sink.attr("flush")(); } return; } // Otherwise: invalid type, raise Python TypeError. PyErr_SetString(PyExc_TypeError, - "`tosa_bytecode_sink` must be either a str (file path) " + "`tosa_sink` must be either a str (file path) " "or any object with a write(bytes) method (i.e. BinaryIO)"); throw pybind11::error_already_set(); } @@ -72,20 +72,41 @@ void DispatchTFLiteToTosaConversion( } // namespace PYBIND11_MODULE(_tosa_converter_for_tflite_wrapper, m) { + pybind11::enum_( + m, "TosaConverterOutputFormat", + R"doc( + TosaConverterOutputFormat enumerates the two supported output modes: + + - Bytecode: emit raw MLIR bytecode. + - Text: emit human readable MLIR text. + )doc") + .value("Bytecode", + tosa_converter_for_tflite::TosaConverterOutputFormat::Bytecode, + "Emit MLIR as raw bytecode.") + .value("Text", tosa_converter_for_tflite::TosaConverterOutputFormat::Text, + "Emit human readable MLIR text.") + .export_values(); + m.def( - "tflite_flatbuffer_to_tosa_mlir_bytecode", - [](const std::string& tflite_input_file, - pybind11::object tosa_bytecode_sink) { - DispatchTFLiteToTosaConversion(tflite_input_file, tosa_bytecode_sink, - {}); + "tflite_flatbuffer_to_tosa_mlir", + [](const std::string& tflite_input_file, pybind11::object tosa_sink, + tosa_converter_for_tflite::TosaConverterOutputFormat output_format = + tosa_converter_for_tflite::TosaConverterOutputFormat::Bytecode) { + DispatchTFLiteToTosaConversion(tflite_input_file, tosa_sink, + output_format, {}); }, - pybind11::arg("tflite_input_file"), pybind11::arg("tosa_bytecode_sink"), + pybind11::arg("tflite_input_file"), pybind11::arg("tosa_sink"), + pybind11::arg("output_format") = + tosa_converter_for_tflite::TosaConverterOutputFormat::Bytecode, R"doc( - Converts a TFLite flatbuffer (.tflite) into TOSA MLIR bytecode. 
+ Converts a TFLite flatbuffer (.tflite) into TOSA MLIR. Args: - tflite_input_file (str): Path to the input TFLite model file. - tosa_bytecode_sink (str or BinaryIO): Path to output tosa file, or a writable binary stream (e.g., BytesIO). + tflite_input_file (str): Path to the input TFLite model. + tosa_sink (str or BinaryIO): Path to output tosa file, or a writable binary stream (e.g., BytesIO). + output_format (TosaConverterOutputFormat, optional): + - TosaConverterOutputFormat.Bytecode: emit MLIR bytecode (default) + - TosaConverterOutputFormat.Text: emit human-readable MLIR text )doc"); /* @@ -93,22 +114,29 @@ PYBIND11_MODULE(_tosa_converter_for_tflite_wrapper, m) { * inference should be preferred */ m.def( - "_tflite_flatbuffer_to_tosa_mlir_bytecode", - [](const std::string& tflite_input_file, - pybind11::object tosa_bytecode_sink, + "_tflite_flatbuffer_to_tosa_mlir", + [](const std::string& tflite_input_file, pybind11::object tosa_sink, const std::unordered_map>& - override_tflite_input_shape) { - DispatchTFLiteToTosaConversion(tflite_input_file, tosa_bytecode_sink, + override_tflite_input_shape, + tosa_converter_for_tflite::TosaConverterOutputFormat output_format = + tosa_converter_for_tflite::TosaConverterOutputFormat::Bytecode) { + DispatchTFLiteToTosaConversion(tflite_input_file, tosa_sink, + output_format, override_tflite_input_shape); }, - pybind11::arg("tflite_input_file"), pybind11::arg("tosa_bytecode_sink"), + pybind11::arg("tflite_input_file"), pybind11::arg("tosa_sink"), pybind11::arg("override_tflite_input_shape"), + pybind11::arg("output_format") = + tosa_converter_for_tflite::TosaConverterOutputFormat::Bytecode, R"doc( - Converts a TFLite flatbuffer into TOSA MLIR bytecode with explicit input shape overrides. + Converts a TFLite flatbuffer into TOSA MLIR with explicit input shape overrides. Args: - tflite_input_file (str): Path to the input TFLite model file. 
- tosa_bytecode_sink (str or BinaryIO): Path to output tosa file, or a writable binary stream (e.g., BytesIO). - override_shapes (dict): Mapping of input tensor names to their overridden shape. + tflite_input_file (str): Path to the input TFLite file. + tosa_sink (str or BinaryIO): Path to output tosa file, or a writable binary stream (e.g., BytesIO). + override_tflite_input_shape (dict): Mapping of input tensor names to their overridden shape. + output_format (TosaConverterOutputFormat, optional): + - TosaConverterOutputFormat.Bytecode: emit MLIR bytecode (default) + - TosaConverterOutputFormat.Text: emit human-readable MLIR text )doc"); }