From c4a4624e966bc887a2220389454638c4501997d4 Mon Sep 17 00:00:00 2001 From: Yunus Kalkan Date: Wed, 18 Jun 2025 18:24:24 +0100 Subject: [PATCH] ONNX RT-GenAI integration into LLM * Add ONNX RT-GenAI backend (v0.8.2) support to LLM * Build ONNX Runtime (v1.22.0) as part of ONNX Runtime-GenAI Integration * Introduce general KleidiAI flag (-DUSE_KLEIDIAI) and configure framework-specific KleidiAI flags * Update README with ONNX Runtime-GenAI backend usage and build instructions * Disable -DLLAMA_CURL by default Change-Id: I6812a3fafb78f55043a7653c7dce267a20670cf3 Signed-off-by: Yunus Kalkan --- CMakeLists.txt | 6 +- README.md | 105 +++++- model_configuration_files/llamaConfig.txt | 4 +- model_configuration_files/onnxrtConfig.txt | 6 + scripts/cmake/check-flag.cmake | 20 ++ scripts/cmake/configuration-options.cmake | 7 +- scripts/cmake/configuration-presets.json | 69 ++-- scripts/cmake/download-resources.cmake | 2 + scripts/py/download_resources.py | 27 +- scripts/py/requirements.json | 77 +++- src/cpp/Llm.cpp | 5 + src/cpp/LlmJni.cpp | 21 +- src/cpp/config/LlmConfig.cpp | 24 +- src/cpp/config/LlmConfig.hpp | 32 ++ src/cpp/frameworks/CMakeLists.txt | 6 +- src/cpp/frameworks/llama_cpp/CMakeLists.txt | 23 +- src/cpp/frameworks/llama_cpp/LlamaImpl.cpp | 6 + src/cpp/frameworks/llama_cpp/LlmImpl.hpp | 20 +- .../onnxruntime_genai/CMakeLists.txt | 177 ++++++++++ .../frameworks/onnxruntime_genai/LlmImpl.hpp | 204 +++++++++++ .../onnxruntime_genai/OnnxrtImpl.cpp | 333 ++++++++++++++++++ src/cpp/interface/Llm.hpp | 6 + src/java/com/arm/Llm.java | 30 +- src/java/com/arm/LlmConfig.java | 47 ++- test/CMakeLists.txt | 4 +- test/cpp/LlmTest.cpp | 3 +- test/cpp/LlmUtils.cpp | 6 + test/java/com/arm/LlmTestJNI.java | 55 +-- 28 files changed, 1162 insertions(+), 163 deletions(-) create mode 100755 model_configuration_files/onnxrtConfig.txt create mode 100644 src/cpp/frameworks/onnxruntime_genai/CMakeLists.txt create mode 100644 src/cpp/frameworks/onnxruntime_genai/LlmImpl.hpp create mode 100644 src/cpp/frameworks/onnxruntime_genai/OnnxrtImpl.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 000e87c..729649e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,10 +13,6 @@ if (NOT CMAKE_TOOLCHAIN_FILE) set(CMAKE_TOOLCHAIN_FILE "${CMAKE_CURRENT_SOURCE_DIR}/scripts/cmake/toolchains/native.cmake") endif() -set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) -set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) -set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) - message(STATUS "Using CMAKE_TOOLCHAIN_FILE: ${CMAKE_TOOLCHAIN_FILE}") # Declare project @@ -28,7 +24,7 @@ project(arm-llm-wrapper list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/scripts/cmake") # Set the runtime and lib directories -set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib/archive) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) diff --git a/README.md b/README.md index 9250a17..adb5745 100644 --- a/README.md +++ b/README.md @@ -12,8 +12,12 @@ * [Prerequisites](#prerequisites) * [Configuration options](#configuration-options) * [Conditional options](#conditional-options) + * [llama cpp options](#llama-cpp-options) + * [onnxruntime genai options](#onnxruntime-genai-options) * [Quick start](#quick-start) * [Neural network](#neural-network) + * [llama cpp model](#llama-cpp-model) + * [onnxruntime genai model](#onnxruntime-genai-model) * [To build for 
Android](#to-build-for-android) * [To build for Linux](#to-build-for-linux) * [Generic aarch64 target](#generic-aarch64-target) @@ -21,6 +25,8 @@ * [Native host build](#native-host-build) * [Building and running tests](#building-and-running-tests) * [To build an executable](#to-build-an-executable) + * [llama cpp](#llama-cpp) + * [onnxruntime genai](#onnxruntime-genai) * [Trademarks](#trademarks) * [License](#license) @@ -29,8 +35,9 @@ This repo is designed for building an [Arm® KleidiAI™](https://www.arm.com/markets/artificial-intelligence/software/kleidi) enabled LLM library using CMake build system. It intends to provide an abstraction for different Machine Learning frameworks/backends that Arm® KleidiAI™ kernels have been integrated into. -Currently, it supports [llama.cpp](https://github.com/ggml-org/llama.cpp) backend but we intend to add support for -other backends soon. +Currently, it supports [llama.cpp](https://github.com/ggml-org/llama.cpp) and +[onnxruntime-genai](https://github.com/microsoft/onnxruntime-genai) backends but we intend to add +support for other backends, such as [mediapipe](https://github.com/google-ai-edge/mediapipe), soon. The backend library (selected at CMake configuration stage) is wrapped by this project's thin C++ layer that could be used directly for testing and evaluations. However, JNI bindings are also provided for developers targeting Android™ based @@ -53,11 +60,9 @@ applications. The project is designed to download the required software sources based on user provided configuration options. CMake presets are available to use and set the following variables: -- `LLM_DEP_NAME`: Currently supports only `llama.cpp`. Support for `mediapipe` and `executorch` may be added later. -- `BUILD_SHARED_LIBS`: Build shared instead of static dependency libraries, specifically - ggml and common, disabled by default. +- `LLM_FRAMEWORK`: Currently supports `llama.cpp` (default framework) and `onnxruntime-genai`. - `BUILD_JNI_LIB`: Build the JNI shared library that other projects can consume, enabled by default. - `BUILD_UNIT_TESTS`: Build C++ unit tests and add them to CTest, JNI tests will also be built, enabled by default. -- `LLAMA_BUILD_COMMON`: Build llama's dependency Common, enabled by default. - `BUILD_EXECUTABLE`: Build standalone applications, disabled by default. > **NOTE**: If you need specific version of Java set the path in `JAVA_HOME` environment variable. @@ -75,29 +80,83 @@ provided configuration options. CMake presets are available to use and set the f ### Conditional options -For `llama.cpp` as dependency, these configuration parameters can be set: +There are different conditional options for different frameworks. + +#### llama cpp options + +For `llama.cpp` as framework, these configuration parameters can be set: - `LLAMA_SRC_DIR`: Source directory path that will be populated by CMake configuration. - `LLAMA_GIT_URL`: Git URL to clone the sources from. - `LLAMA_GIT_SHA`: Git SHA for checkout. +- `LLAMA_BUILD_COMMON`: Build llama's dependency Common, enabled by default. +- `BUILD_SHARED_LIBS`: Build shared instead of static dependency libraries, specifically - ggml and common, disabled by default. +- `LLAMA_CURL`: Enable HTTP transport via libcurl for remote models or features requiring network communication, disabled by default. + +#### onnxruntime genai options + +When using `onnxruntime-genai`, the `onnxruntime` dependency will be built from source. 
To customize
+the versions of both `onnxruntime` and `onnxruntime-genai`, the following configuration parameters
+can be used:
+
+onnxruntime:
+- `ONNXRUNTIME_SRC_DIR`: Source directory path that will be populated by CMake
+  configuration.
+- `ONNXRUNTIME_GIT_URL`: Git URL to clone the sources from.
+- `ONNXRUNTIME_GIT_TAG`: Git tag or commit SHA to check out.
+
+onnxruntime-genai:
+- `ONNXRT_GENAI_SRC_DIR`: Source directory path that will be populated by CMake
+  configuration.
+- `ONNXRT_GENAI_GIT_URL`: Git URL to clone the sources from.
+- `ONNXRT_GENAI_GIT_TAG`: Git tag or commit SHA to check out.
+
+> **NOTE**: This repository has been tested with `onnxruntime` version `v1.22.0` and
+`onnxruntime-genai` version `v0.8.2`.

## Quick start

By default, the JNI builds are enabled, and Arm® KleidiAI™ kernels are enabled on arm64/aarch64.
-To disable these, configure with: `-DGGML_CPU_KLEIDIAI=OFF`.
+To disable these, configure with: `-DUSE_KLEIDIAI=OFF`.

### Neural network

-This project uses the **phi-2 model** as its default network. The model is distributed using the
-**Q4_0 quantization format**, which is highly recommended as it delivers effective inference times by striking a
-balance between computational efficiency and model performance.
+There are different default models for different frameworks.
+
+#### llama cpp model
+
+This project uses the **phi-2 model** as its default network for the `llama.cpp` framework.
+The model is distributed using the **Q4_0 quantization format**, which is highly recommended as it
+delivers effective inference times by striking a balance between computational efficiency and model performance.

- You can access the model from [Hugging Face](https://huggingface.co/ggml-org/models/blob/main/phi-2/ggml-model-q4_0.gguf).
- The default model configuration is declared in the [`requirements.json`](scripts/py/requirements.json) file.

However, any model supported by the backend library could be used.

-> **NOTE**: Currently only Q4_0 models are accelerated by Arm® KleidiAI™ kernels in llama.cpp.
+> **NOTE**: Currently, only Q4_0 models are accelerated by Arm® KleidiAI™ kernels in `llama.cpp`.
+
+#### onnxruntime genai model
+
+This project uses the **Phi-4-mini-instruct-onnx** model as its default network for the `onnxruntime-genai` framework.
+The model is distributed in the **int4 quantization format** with a **block size of 32**, which is highly recommended as it
+delivers effective inference times by striking a balance between computational efficiency and model performance.
+
+- You can access the model from [Hugging Face](https://huggingface.co/microsoft/Phi-4-mini-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4).
+- The default model configuration is declared in the [`requirements.json`](scripts/py/requirements.json) file.
+
+However, any model supported by the backend library could be used.
+
+To use an ONNX model with this framework, the following files are required:
+- `genai_config.json`: Configuration file
+- `model_name.onnx`: ONNX model
+- `model_name.onnx.data`: ONNX model data
+- `tokenizer.json`: Tokenizer file
+- `tokenizer_config.json`: Tokenizer config file
+
+These files are essential for loading and running ONNX models; an example layout is shown below.
+
+> **NOTE**: Currently, only int4 models with a block size of 32 are accelerated by Arm® KleidiAI™ kernels in `onnxruntime-genai`.
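+
+For reference, after the resource download step the default model files are expected to land under
+`resources_downloaded/models/onnxruntime-genai` (destinations taken from
+[`requirements.json`](scripts/py/requirements.json)). The sketch below is an illustrative example of that
+layout rather than a guaranteed structure:
+
+```
+resources_downloaded/models/onnxruntime-genai/
+├── genai_config.json
+├── model.onnx
+├── model.onnx.data
+├── tokenizer.json
+└── tokenizer_config.json
+```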
### To build for Android For Android™ build, ensure the `NDK_PATH` is set to installed Android™ NDK, specify Android™ ABI and platform needed: @@ -114,7 +173,7 @@ cmake --build ./build ### To build for Linux -Building for Linux targets, with llama.cpp backend, `GGML_CPU_ARM_ARCH` can be set to provide the architecture flags. +Building for Linux targets, with `llama.cpp` backend, `GGML_CPU_ARM_ARCH` can be set to provide the architecture flags. #### Generic aarch64 target @@ -131,7 +190,7 @@ cmake --build ./build #### Aarch64 target with SME -To build for aarch64 Linux system with [Scalable Matrix Extensions](https://developer.arm.com/documentation/109246/0100/SME-Overview/SME-and-SME2), ensure `GGML_CPU_ARM_ARCH` is set with needed feature flags as below: +To build for aarch64 Linux system with [Scalable Matrix Extensions](https://developer.arm.com/documentation/109246/0100/SME-Overview/SME-and-SME2), for `llama.cpp` ensure `GGML_CPU_ARM_ARCH` is set with needed feature flags as below: ```shell cmake -B build \ @@ -145,16 +204,16 @@ cmake --build ./build Once built, a standalone application can be executed to get performance. If `FEAT_SME` is available on deployment target, environment variable `GGML_KLEIDIAI_SME` can be used to -toggle the use of SME kernels during execution. For example: +toggle the use of SME kernels during execution for `llama.cpp`. For example: ```shell -GGML_KLEIDIAI_SME=1 ./build/bin/llama-cli -m resources_downloaded/models/model.gguf -t 1 -p "What is a car?" +GGML_KLEIDIAI_SME=1 ./build/bin/llama-cli -m resources_downloaded/models/llama.cpp/model.gguf -t 1 -p "What is a car?" ``` To run without invoking SME kernels, set `GGML_KLEIDIAI_SME=0` during execution: ```shell -GGML_KLEIDIAI_SME=0 ./build/bin/llama-cli -m resources_downloaded/models/model.gguf -t 1 -p "What is a car?" +GGML_KLEIDIAI_SME=0 ./build/bin/llama-cli -m resources_downloaded/models/llama.cpp/model.gguf -t 1 -p "What is a car?" ``` > **NOTE**: In some cases, it may be desirable to build a statically linked executable. For llama.cpp backend @@ -181,6 +240,8 @@ cmake --build ./build ctest --test-dir ./build ``` +> **NOTE**: For consistent and reliable test results, avoid using the `--parallel` option when running tests. + This should produce something like: ```shell Internal ctest changing into directory: /home/user/llm/build @@ -221,12 +282,22 @@ cmake -B build \ cmake --build ./build ``` +### llama cpp + You can run either executable from command line and add your prompt for example the following: ``` -./build/bin/llama-cli -m resources_downloaded/models/model.gguf --prompt "What is the capital of France" +./build/bin/llama-cli -m resources_downloaded/models/llama.cpp/model.gguf --prompt "What is the capital of France" ``` More information can be found at `llama.cpp/examples/main/README.md` on how this executable can be run. +### onnxruntime genai + +You can run model_benchmark executable from command line: +``` +./build/bin/model_benchmark -i resources_downloaded/models/onnxruntime-genai +``` +More information can be found at `onnxruntime-genai/benchmark/c/readme.md` on how this executable can be run. 
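+
+For example, a full configure, build, and benchmark run with the `onnxruntime-genai` backend could look as
+follows; the flags shown are the configuration options described earlier, and additional toolchain or
+platform options may be needed for your environment:
+
+```shell
+cmake -B build \
+  -DLLM_FRAMEWORK=onnxruntime-genai \
+  -DBUILD_EXECUTABLE=ON \
+  -DUSE_KLEIDIAI=ON
+cmake --build ./build
+./build/bin/model_benchmark -i resources_downloaded/models/onnxruntime-genai
+```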
+ ## Trademarks * Arm® and KleidiAI™ are registered trademarks or trademarks of Arm® Limited (or its subsidiaries) in the US and/or diff --git a/model_configuration_files/llamaConfig.txt b/model_configuration_files/llamaConfig.txt index 457f2b6..b3b8a17 100755 --- a/model_configuration_files/llamaConfig.txt +++ b/model_configuration_files/llamaConfig.txt @@ -1,4 +1,6 @@ modelTag=Orbita: +userTag= +endTag= llmPrefix=Transcript of a dialog, where the User interacts with an AI Assistant named Orbita. Orbita is helpful, polite, honest, good at writing and answers honestly with a maximum of two sentences. User: stopWords=Orbita:,User:,AI:,<|user|>,Assistant:,user:,[end of text],<|endoftext|>,model:,Question:,"\n\n",Consider the following scenario:\n -llmModelName=model.gguf \ No newline at end of file +llmModelName=llama.cpp/model.gguf diff --git a/model_configuration_files/onnxrtConfig.txt b/model_configuration_files/onnxrtConfig.txt new file mode 100755 index 0000000..8b6ccd7 --- /dev/null +++ b/model_configuration_files/onnxrtConfig.txt @@ -0,0 +1,6 @@ +modelTag=<|assistant|> +userTag=<|user|> +endTag=<|end|> +llmPrefix=<|system|>Transcript of a dialog, where the User interacts with an AI Assistant named Orbita. Orbita is helpful, polite, honest, good at writing and answers honestly with a maximum of two sentences<|end|> +stopWords=Orbita:,User:,User,AI:,<|user|>,Assistant:,<|assistant|>,<|end|>,user:,[end of text],<|endoftext|>,model:,Question:,Consider the following scenario:\n +llmModelName=onnxruntime-genai diff --git a/scripts/cmake/check-flag.cmake b/scripts/cmake/check-flag.cmake index f714a2e..d37f10c 100644 --- a/scripts/cmake/check-flag.cmake +++ b/scripts/cmake/check-flag.cmake @@ -38,3 +38,23 @@ function(check_compiler_support LANG FLAG) message(STATUS "The compiler supports the ${LANG} flag ${FLAG}!") endif() endfunction() + +function(set_kleidiai_flag) + + # If the user has NOT explicitly set onnxruntime_USE_KLEIDIAI + if (NOT DEFINED USE_KLEIDIAI) + # if we are on arm64/aarch64, then default KleidiAI to ON. + if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64|ARM64)$") + set(USE_KLEIDIAI ON CACHE BOOL + "Enable KleidiAI by default on ${CMAKE_SYSTEM_PROCESSOR}") + message(STATUS "KleidiAI enabled by default") + # if we are NOT on arm64/aarch64, then default KleidiAI to OFF. 
+ else() + set(USE_KLEIDIAI OFF CACHE BOOL + "Disable KleidiAI by default on ${CMAKE_SYSTEM_PROCESSOR}") + message(STATUS "KleidiAI disabled by default") + endif() + else () + message(STATUS "KleidiAI: ${USE_KLEIDIAI}") + endif() +endfunction() \ No newline at end of file diff --git a/scripts/cmake/configuration-options.cmake b/scripts/cmake/configuration-options.cmake index 8f8f60d..17787c8 100644 --- a/scripts/cmake/configuration-options.cmake +++ b/scripts/cmake/configuration-options.cmake @@ -6,14 +6,13 @@ include_guard(GLOBAL) -set(LLM_DEP_NAME "llama.cpp" CACHE STRING +set(LLM_FRAMEWORK "llama.cpp" CACHE STRING "Dependency name to configure the project for") # Available options: -set(CACHE LLM_DEP_NAME PROPERTY STRINGS +set(CACHE LLM_FRAMEWORK PROPERTY STRINGS "llama.cpp" - "mediapipe" - "executorch") + "onnxruntime-genai") set(DOWNLOADS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/resources_downloaded CACHE STRING diff --git a/scripts/cmake/configuration-presets.json b/scripts/cmake/configuration-presets.json index 9c6fd46..2d1cd70 100644 --- a/scripts/cmake/configuration-presets.json +++ b/scripts/cmake/configuration-presets.json @@ -50,39 +50,38 @@ } } }, - - { - "name": "llama.cpp", - "description": "Build with llama.cpp as the backend.", - "hidden": true, - "cacheVariables": { - "LLM_DEP_NAME": { - "type": "STRING", - "value": "llama.cpp" - } - } - }, - { - "name": "build-shared-enabled", - "description": "Configure with shared libraries enabled.", - "hidden": true, - "cacheVariables": { - "BUILD_SHARED_LIBS": { - "type": "BOOL", - "value": "ON" - } - } - }, - { - "name": "build-shared-disabled", - "description": "Configure with shared libraries disabled.", - "hidden": true, - "cacheVariables": { - "BUILD_SHARED_LIBS": { - "type": "BOOL", - "value": "OFF" - } - } - } - ] + { + "name": "llama.cpp", + "description": "Build with llama.cpp as the backend.", + "hidden": true, + "cacheVariables": { + "LLM_FRAMEWORK": { + "type": "STRING", + "value": "llama.cpp" + } + } + }, + { + "name": "build-shared-enabled", + "description": "Configure with shared libraries enabled.", + "hidden": true, + "cacheVariables": { + "BUILD_SHARED_LIBS": { + "type": "BOOL", + "value": "ON" + } + } + }, + { + "name": "build-shared-disabled", + "description": "Configure with shared libraries disabled.", + "hidden": true, + "cacheVariables": { + "BUILD_SHARED_LIBS": { + "type": "BOOL", + "value": "OFF" + } + } + } + ] } diff --git a/scripts/cmake/download-resources.cmake b/scripts/cmake/download-resources.cmake index 4b337e0..b523529 100644 --- a/scripts/cmake/download-resources.cmake +++ b/scripts/cmake/download-resources.cmake @@ -45,6 +45,8 @@ execute_process( ${CMAKE_CURRENT_LIST_DIR}/../py/requirements.json --download-dir ${DOWNLOADS_DIR} + --llm-framework + ${LLM_FRAMEWORK} RESULT_VARIABLE return_code) # Release the lock: diff --git a/scripts/py/download_resources.py b/scripts/py/download_resources.py index 22ccbc3..5794ab8 100644 --- a/scripts/py/download_resources.py +++ b/scripts/py/download_resources.py @@ -11,6 +11,7 @@ import urllib.request import logging import sys from argparse import ArgumentParser + def download_file(url: str, dest: Path) -> None: """ Download a file @@ -59,12 +60,20 @@ def download_resources(resources_file: Path, download_dir: Path) -> None: for resource_type in resource_list: resource_dir = Path(download_dir / resource_type) resource_dir.mkdir(exist_ok=True) - for resource_data in resource_list[resource_type]: - logging.info(f'Name: {resource_data["name"]}') - logging.info(f'Purpose: 
{resource_data["purpose"]}') - logging.info(f'Dest: {resource_data["destination"]}') - logging.info(f'URL: {resource_data["url"]}') - logging.info(f'SHA256: {resource_data["sha256sum"]}') + + if resource_type == "models": + resources = resource_list[resource_type][llm_framework] + model_dir = Path(resource_dir / llm_framework) + model_dir.mkdir(exist_ok=True) + else: + resources = resource_list[resource_type] + + for resource_data in resources: + logging.info(f'Name: {resource_data["name"]}') + logging.info(f'Purpose: {resource_data["purpose"]}') + logging.info(f'Dest: {resource_data["destination"]}') + logging.info(f'URL: {resource_data["url"]}') + logging.info(f'SHA256: {resource_data["sha256sum"]}') url = resource_data['url'] dest = resource_dir / resource_data['destination'] @@ -83,6 +92,7 @@ def download_resources(resources_file: Path, download_dir: Path) -> None: current_file_dir = Path(__file__).parent.resolve() default_requirements_file = current_file_dir / 'requirements.json' default_download_location = current_file_dir / '..' / '..' / 'resources_downloaded' +default_llm_framework = "llama.cpp" if __name__ == "__main__": @@ -98,10 +108,15 @@ if __name__ == "__main__": "--download-dir", help="Path to where resources should be downloaded.", default=default_download_location) + parser.add_argument( + "--llm-framework", + help="LLM framework from which the model will be downloaded.", + default=default_llm_framework) args = parser.parse_args() req_file = Path(args.requirements_file) download_dir = Path(args.download_dir) + llm_framework = args.llm_framework if not req_file.exists(): raise FileNotFoundError(f'{req_file} does not exist') diff --git a/scripts/py/requirements.json b/scripts/py/requirements.json index 1340e7c..a3f81fb 100644 --- a/scripts/py/requirements.json +++ b/scripts/py/requirements.json @@ -1,27 +1,70 @@ { - "models": [ - { + "models": { + "llama.cpp": + [ + { "name": "phi-2 model", "purpose": "To enable basic testing", - "destination": "model.gguf", + "destination": "llama.cpp/model.gguf", "url": "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true", "sha256sum": "fd506d24a4bee6997a566b02b65715af5cadb433c6a3a47a74b467badc5727ca" - } - ], - "test_utils": [ + } + ], + "onnxruntime-genai" : + [ + { + "name": "genai config file", + "purpose": "Model configuration file", + "destination": "onnxruntime-genai/genai_config.json", + "url": "https://huggingface.co/microsoft/Phi-4-mini-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/genai_config.json", + "sha256sum": "0fcfa1e663f2bc867f8dc62fae65dd0924f0a4d68b43d1234df742dd19171470" + }, + { + "name": "phi-4-mini-instruct-onnx (CPU/mobile)", + "purpose": "ONNX model for basic testing", + "destination": "onnxruntime-genai/model.onnx", + "url": "https://huggingface.co/microsoft/Phi-4-mini-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/model.onnx", + "sha256sum": "701aa5d185b6a782bc27104a990dd5b634fa507840b7c42f7ee6f1fb812d0b83" + }, + { + "name": "phi-4-mini-instruct-onnx.data (CPU/mobile)", + "purpose": "ONNX model data for basic testing", + "destination": "onnxruntime-genai/model.onnx.data", + "url": "https://huggingface.co/microsoft/Phi-4-mini-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/model.onnx.data", + "sha256sum": "cb0267fa60befa1a4ade8c98b6d32a3d67f51abbd307c7f793f132e8d9092131" + }, + { + "name": "tokenizer file", + "purpose": "Tokenizer for the model", + "destination": 
"onnxruntime-genai/tokenizer.json", + "url": "https://huggingface.co/microsoft/Phi-4-mini-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/tokenizer.json", + "sha256sum": "382cc235b56c725945e149cc25f191da667c836655efd0857b004320e90e91ea" + }, + { + "name": "tokenizer config file", + "purpose": "Tokenizer config for the model", + "destination": "onnxruntime-genai/tokenizer_config.json", + "url": "https://huggingface.co/microsoft/Phi-4-mini-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/tokenizer_config.json", + "sha256sum": "c565326a315fbe62cda093a59d298828c8f3f823122661325f41f3ba577a7dec" + } + ] + }, + + "test_utils": + [ { - "name": "hamcrest", - "purpose": "JUnit tests for validating the Java API", - "destination": "hamcrest-all-1.3.jar", - "url": "https://repo1.maven.org/maven2/org/hamcrest/hamcrest-all/1.3/hamcrest-all-1.3.jar", - "sha256sum": "4877670629ab96f34f5f90ab283125fcd9acb7e683e66319a68be6eb2cca60de" + "name": "hamcrest", + "purpose": "JUnit tests for validating the Java API", + "destination": "hamcrest-all-1.3.jar", + "url": "https://repo1.maven.org/maven2/org/hamcrest/hamcrest-all/1.3/hamcrest-all-1.3.jar", + "sha256sum": "4877670629ab96f34f5f90ab283125fcd9acb7e683e66319a68be6eb2cca60de" }, { - "name": "JUnit", - "purpose": "JUnit tests for validating the Java API", - "destination": "junit-4.13.2.jar", - "url": "https://repo1.maven.org/maven2/junit/junit/4.13.2/junit-4.13.2.jar", - "sha256sum": "8e495b634469d64fb8acfa3495a065cbacc8a0fff55ce1e31007be4c16dc57d3" + "name": "JUnit", + "purpose": "JUnit tests for validating the Java API", + "destination": "junit-4.13.2.jar", + "url": "https://repo1.maven.org/maven2/junit/junit/4.13.2/junit-4.13.2.jar", + "sha256sum": "8e495b634469d64fb8acfa3495a065cbacc8a0fff55ce1e31007be4c16dc57d3" } ] -} +} \ No newline at end of file diff --git a/src/cpp/Llm.cpp b/src/cpp/Llm.cpp index 5e337bb..38ff2c3 100644 --- a/src/cpp/Llm.cpp +++ b/src/cpp/Llm.cpp @@ -69,3 +69,8 @@ std::string LLM::BenchModel(int& nPrompts, int& nEvalPrompts, int& nMaxSeq, int& { return this->m_impl->BenchModel(nPrompts, nEvalPrompts, nMaxSeq, nRep); } + +std::string LLM::GetFrameworkType() +{ + return this->m_impl->GetFrameworkType(); +} diff --git a/src/cpp/LlmJni.cpp b/src/cpp/LlmJni.cpp index 4f595dd..662a6e1 100644 --- a/src/cpp/LlmJni.cpp +++ b/src/cpp/LlmJni.cpp @@ -8,7 +8,7 @@ #include -static std::unique_ptr llm{nullptr}; +static std::unique_ptr llm = std::make_unique(); #ifdef __cplusplus extern "C" { @@ -17,16 +17,22 @@ extern "C" { JNIEXPORT jlong JNICALL Java_com_arm_Llm_createLlmConfig(JNIEnv* env, jobject /* this */, jstring jModelTag, + jstring jUserTag, + jstring jEndTag, jstring jModelPath, jstring jLlmPrefix, jint jNumThreads, jint jBatchSize) { const char* modelTag = env->GetStringUTFChars(jModelTag, nullptr); + const char* userTag = env->GetStringUTFChars(jUserTag, nullptr); + const char* endTag = env->GetStringUTFChars(jEndTag, nullptr); const char* modelPath = env->GetStringUTFChars(jModelPath, nullptr); const char* llmPrefix = env->GetStringUTFChars(jLlmPrefix, nullptr); auto* config = new LlmConfig(std::string(modelTag), + std::string(userTag), + std::string(endTag), std::string(modelPath), std::string(llmPrefix), static_cast(jNumThreads), @@ -34,6 +40,8 @@ JNIEXPORT jlong JNICALL Java_com_arm_Llm_createLlmConfig(JNIEnv* env, // Clean up env->ReleaseStringUTFChars(jModelTag, modelTag); + env->ReleaseStringUTFChars(jUserTag, userTag); + env->ReleaseStringUTFChars(jEndTag, endTag); 
env->ReleaseStringUTFChars(jModelPath, modelPath); env->ReleaseStringUTFChars(jLlmPrefix, llmPrefix); @@ -43,7 +51,6 @@ JNIEXPORT jlong JNICALL Java_com_arm_Llm_createLlmConfig(JNIEnv* env, JNIEXPORT jlong JNICALL Java_com_arm_Llm_loadModel(JNIEnv* env, jobject, jlong pconfig) { auto config = reinterpret_cast(pconfig); - llm = std::make_unique(); llm->LlmInit(*config); return 0; } @@ -74,7 +81,6 @@ JNIEXPORT jfloat JNICALL Java_com_arm_Llm_getEncodeRate(JNIEnv* env, jobject) JNIEXPORT jfloat JNICALL Java_com_arm_Llm_getDecodeRate(JNIEnv* env, jobject) { - float result = llm->GetDecodeTimings(); return result; } @@ -83,14 +89,17 @@ JNIEXPORT void JNICALL Java_com_arm_Llm_resetTimings(JNIEnv* env, jobject) { llm->ResetTimings(); } + JNIEXPORT jsize JNICALL Java_com_arm_Llm_getChatProgress(JNIEnv* env, jobject) { return llm->GetChatProgress(); } + JNIEXPORT void JNICALL Java_com_arm_Llm_resetContext(JNIEnv* env, jobject) { llm->ResetContext(); } + JNIEXPORT jstring JNICALL Java_com_arm_Llm_benchModel( JNIEnv* env, jobject, jint nPrompts, jint nEvalPrompts, jint nMaxSeq, jint nRep) { @@ -98,6 +107,12 @@ JNIEXPORT jstring JNICALL Java_com_arm_Llm_benchModel( return env->NewStringUTF(result.c_str()); } +JNIEXPORT jstring JNICALL Java_com_arm_Llm_getFrameworkType(JNIEnv* env, jobject) +{ + std::string frameworkType = llm->GetFrameworkType(); + return env->NewStringUTF(frameworkType.c_str()); +} + #ifdef __cplusplus } #endif diff --git a/src/cpp/config/LlmConfig.cpp b/src/cpp/config/LlmConfig.cpp index db5f95f..f80c886 100644 --- a/src/cpp/config/LlmConfig.cpp +++ b/src/cpp/config/LlmConfig.cpp @@ -7,16 +7,28 @@ #include LlmConfig::LlmConfig(const std::string& modelTag, + const std::string& userTag, + const std::string& endTag, const std::string& modelPath, const std::string& llmPrefix, int numThreads, int batchSize) : - m_modelTag(modelTag), m_modelPath(modelPath), m_llmPrefix(llmPrefix) + m_modelTag(modelTag), m_userTag(userTag), m_endTag(endTag), m_modelPath(modelPath), m_llmPrefix(llmPrefix) { SetNumThreads(numThreads); SetBatchSize(batchSize); } +std::string LlmConfig::GetEndTag() const +{ + return this->m_endTag; +} + +std::string LlmConfig::GetUserTag() const +{ + return this->m_userTag; +} + std::string LlmConfig::GetModelTag() const { return this->m_modelTag; @@ -47,6 +59,16 @@ void LlmConfig::SetModelTag(const std::string& modelIdentifier) this->m_modelTag = modelIdentifier; } +void LlmConfig::SetUserTag(const std::string& userTag) +{ + this->m_userTag = userTag; +} + +void LlmConfig::SetEndTag(const std::string& endTag) +{ + this->m_endTag = endTag; +} + void LlmConfig::SetModelPath(const std::string& basePath) { this->m_modelPath = basePath; diff --git a/src/cpp/config/LlmConfig.hpp b/src/cpp/config/LlmConfig.hpp index 66713ab..2216d66 100644 --- a/src/cpp/config/LlmConfig.hpp +++ b/src/cpp/config/LlmConfig.hpp @@ -15,6 +15,8 @@ class LlmConfig { private: std::string m_modelTag{}; + std::string m_userTag{}; + std::string m_endTag{}; std::string m_modelPath{}; std::string m_llmPrefix{}; int m_numThreads{}; @@ -24,12 +26,16 @@ public: /** * LlmConfig * @param modelTag Model tag for the LLM model + * @param userTag User tag for the prompt + * @param endTag End tag for the prompt * @param modelPath Path to the model * @param llmPrefix LLM prefix to use * @param numThreads Number of threads to use * @param batchSize Batch size to use */ LlmConfig(const std::string& modelTag, + const std::string& userTag, + const std::string& endTag, const std::string& modelPath, const std::string& 
llmPrefix, int numThreads, @@ -37,6 +43,18 @@ public: LlmConfig() = default; + /** + * Returns the end tag string. + * @return endTag + */ + std::string GetEndTag() const; + + /** + * Returns the user tag string. + * @return userTag + */ + std::string GetUserTag() const; + /** * Returns the model tag string (The name to appear in conversation with the LLM). * @return modelTag @@ -74,6 +92,20 @@ public: */ void SetModelTag(const std::string& modelIdentifier); + /** + * Sets the user tag + * @param userTag is the user tag added at the beginning of each user question to make model + * respond appropriately + */ + void SetUserTag(const std::string& userTag); + + /** + * Sets the end tag + * @param endTag is the end tag added at the end of each user question to make model + * respond appropriately + */ + void SetEndTag(const std::string& endTag); + /** * Sets the file path to the model. * @param basePath absolute path to load llm model diff --git a/src/cpp/frameworks/CMakeLists.txt b/src/cpp/frameworks/CMakeLists.txt index 8416682..7a96bcf 100644 --- a/src/cpp/frameworks/CMakeLists.txt +++ b/src/cpp/frameworks/CMakeLists.txt @@ -5,8 +5,10 @@ # # Pull in LLM framework library -if (${LLM_DEP_NAME} STREQUAL "llama.cpp") +if (${LLM_FRAMEWORK} STREQUAL "llama.cpp") add_subdirectory(llama_cpp) +elseif(${LLM_FRAMEWORK} STREQUAL "onnxruntime-genai") + add_subdirectory(onnxruntime_genai) else() - message(FATAL_ERROR "${LLM_DEP_NAME} is currently not supported :(") + message(FATAL_ERROR "${LLM_FRAMEWORK} is currently not supported :(") endif() diff --git a/src/cpp/frameworks/llama_cpp/CMakeLists.txt b/src/cpp/frameworks/llama_cpp/CMakeLists.txt index 08eb37d..b6be17b 100644 --- a/src/cpp/frameworks/llama_cpp/CMakeLists.txt +++ b/src/cpp/frameworks/llama_cpp/CMakeLists.txt @@ -10,6 +10,7 @@ project(llama-cpp-wrapper LANGUAGES C CXX ASM) include(FetchContent) +include(check-flag) # Where should llama.cpp sources be cloned into? # It might make sense to download sources into resources as well and not @@ -29,22 +30,14 @@ set(LLAMA_GIT_SHA "a4090d1" "Git commit SHA for llama.cpp repo") set(LLAMA_BUILD_EXAMPLES ${BUILD_EXECUTABLE} CACHE BOOL "Build llama.cpp examples") +set(LLAMA_CURL OFF CACHE BOOL "llama: use libcurl to download model from an URL") -# If the user has NOT explicitly set GGML_CPU_KLEIDIAI -if (NOT DEFINED GGML_CPU_KLEIDIAI) - # if we are on arm64/aarch64, then default KleidiAI to ON. - if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64|ARM64)$") - set(GGML_CPU_KLEIDIAI ON CACHE BOOL - "Enable KleidiAI by default on ${CMAKE_SYSTEM_PROCESSOR}") - message(STATUS "KleidiAI enabled by default") - # if we are NOT on arm64/aarch64, then default KleidiAI to OFF. 
- else() - set(GGML_CPU_KLEIDIAI OFF CACHE BOOL - "Disable KleidiAI by default on ${CMAKE_SYSTEM_PROCESSOR}") - message(STATUS "KleidiAI disabled by default") - endif() -else () - message(STATUS "KleidiAI: ${GGML_CPU_KLEIDIAI}") +# KleidiAI configuration +if(DEFINED GGML_CPU_KLEIDIAI) + message(FATAL_ERROR "Don't use framework specific KleidiAI flag, configure 'USE_KLEIDIAI' flag instead.") +else() + set_kleidiai_flag() + set(GGML_CPU_KLEIDIAI ${USE_KLEIDIAI}) endif() set(GGML_BUILD_NUMBER 1) # We do a shallow clone `--depth=1` diff --git a/src/cpp/frameworks/llama_cpp/LlamaImpl.cpp b/src/cpp/frameworks/llama_cpp/LlamaImpl.cpp index e5d9e21..ffc0089 100644 --- a/src/cpp/frameworks/llama_cpp/LlamaImpl.cpp +++ b/src/cpp/frameworks/llama_cpp/LlamaImpl.cpp @@ -288,6 +288,11 @@ size_t LLM::LLMImpl::GetChatProgress() const return this->m_contextFilled; } +std::string LLM::LLMImpl::GetFrameworkType() +{ + return this->m_frameworkType; +} + std::string LLM::LLMImpl::BenchModel(int& prompts, int& eval_prompts, int& n_max_sq, int& n_rep) { auto prompts_avg = 0.0; @@ -435,3 +440,4 @@ static bool is_valid_utf8(const char* string) } return true; } + diff --git a/src/cpp/frameworks/llama_cpp/LlmImpl.hpp b/src/cpp/frameworks/llama_cpp/LlmImpl.hpp index 1d3f027..b969f18 100644 --- a/src/cpp/frameworks/llama_cpp/LlmImpl.hpp +++ b/src/cpp/frameworks/llama_cpp/LlmImpl.hpp @@ -28,13 +28,13 @@ public: ~LLMImpl(); /** - * Method to Initialize a llama_model + * Method to initialize a llama_model * @param config Configuration class with model's parameter and user defined parameters */ void LlmInit(const LlmConfig& config); /** - * Method to Free all allocations pertaining to llama model + * Method to free all allocations pertaining to llama model */ void FreeLlm(); @@ -62,7 +62,7 @@ public: std::string SystemInfo(); /** - * Method to reset Conversation history and preserve Model's character prefix. + * Method to reset conversation history and preserve model's character prefix. * If model's prefix is not defined all conversation history would be cleared */ void ResetContext(); @@ -75,18 +75,18 @@ public: /** * Method to wrap CompletionLoop function - * @return the next Token for Encoded Prompt + * @return the next token for encoded prompt */ std::string NextToken(); /** - * @brief The Method return the percentage of chat context filled + * The Method return the percentage of chat context filled * @return chat capacity filled in cache as percentage number */ size_t GetChatProgress() const; /** - * @brief Benchmarks the performance of the LLM model. + * Benchmarks the performance of the LLM model. * * This function evaluates the model's performance by processing a specified number of prompts * and generating text sequences. It measures the speed of prompt evaluation and text @@ -101,10 +101,16 @@ public: * size, number of parameters, backend information, and performance metrics for prompt * evaluation and text generation. 
*/ - std::string BenchModel(int& prompts, int& eval_prompts, int& n_max_sq, int& n_rep); + /** + * Method to get framework type + * @return string framework type + */ + std::string GetFrameworkType(); + private: + std::string m_frameworkType{"llama.cpp"}; llama_context* m_llmContext{nullptr}; llama_model* m_llmModel{nullptr}; llama_batch m_llmBatch{}; diff --git a/src/cpp/frameworks/onnxruntime_genai/CMakeLists.txt b/src/cpp/frameworks/onnxruntime_genai/CMakeLists.txt new file mode 100644 index 0000000..0528e41 --- /dev/null +++ b/src/cpp/frameworks/onnxruntime_genai/CMakeLists.txt @@ -0,0 +1,177 @@ +# +# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +# Declare project +project(onnxruntime-genai-wrapper + DESCRIPTION "ONNX Runtime GenAI wrapper interface implementation" + LANGUAGES C CXX ASM) + +include(FetchContent) +include(check-flag) + +# -------------------------------------------------------------------- +# onnxruntime build +# -------------------------------------------------------------------- + +# KleidiAI configuration +if(DEFINED onnxruntime_USE_KLEIDIAI) + message(FATAL_ERROR "Don't use framework specific KleidiAI flag, configure 'USE_KLEIDIAI' flag instead.") +else() + set_kleidiai_flag() + set(onnxruntime_USE_KLEIDIAI ${USE_KLEIDIAI}) +endif() + +# -------------------------------------------------------------------- +# onnxruntime build flags +# onnxruntime_BUILD_SHARED_LIB needs to be ON to produce libonnxruntime.so +# which is a dependency for libonnruntime-genai.so +# -------------------------------------------------------------------- + +set(onnxruntime_BUILD_UNIT_TESTS OFF) +set(onnxruntime_ENABLE_TESTS OFF) +set(onnxruntime_DISABLE_RTTI ON) +set(onnxruntime_DISABLE_FLOAT8_TYPES ON) +set(onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS OFF) +set(onnxruntime_USE_MIMALLOC OFF) +set(onnxruntime_BUILD_SHARED_LIB ON) +set(ONNX_BUILD_SHARED_LIBS OFF) + +set(ONNXRUNTIME_GIT_URL "https://github.com/microsoft/onnxruntime.git" + CACHE STRING "Git URL for onnxruntime repo") + +set(ONNXRUNTIME_GIT_TAG "v1.22.0" + CACHE STRING "Path where onnxruntime repo should be cloned") + +set(ONNXRUNTIME_SRC_DIR "${CMAKE_BINARY_DIR}/onnxruntime" + CACHE PATH "Source dir") + +set(ONNXRUNTIME_BUILD_DIR "${CMAKE_BINARY_DIR}/ort_build" + CACHE PATH "Onnxruntime build dir") + +file(MAKE_DIRECTORY "${ONNXRUNTIME_BUILD_DIR}") + +FetchContent_Declare( + onnxruntime + GIT_REPOSITORY ${ONNXRUNTIME_GIT_URL} + GIT_TAG ${ONNXRUNTIME_GIT_TAG} + SOURCE_DIR ${ONNXRUNTIME_SRC_DIR} + GIT_SHALLOW 1 +) + +FetchContent_MakeAvailable(onnxruntime) + +add_subdirectory( + ${ONNXRUNTIME_SRC_DIR}/cmake + ${ONNXRUNTIME_BUILD_DIR}) + +# --------------------------------------------------------------------------- +# Temporary workaround: +# +# The onnxruntime-genai build expects the onnxruntime header files and the shared +# library to reside in a very specific folder structure at configuration +# time. +# +# To build onnxruntime from source (avoiding the pre-built download), we create a +# temporary staging directory that: +# - Copies in the onnxruntime headers required by onnxruntime-genai, and +# - Drops in a placeholder onnxruntime shared library just long enough for +# onnxruntime-genai to detect it (the file is removed immediately afterward). +# +# The actual onnxruntime shared library produced by the build is written to +# ${CMAKE_BINARY_DIR}/lib. 
+# --------------------------------------------------------------------------- + +set(ORT_HOME "${ONNXRUNTIME_BUILD_DIR}/ort-tmp-src" CACHE STRING "Ort path") + +if(ANDROID) + set(ORT_INCLUDE_DIR "${ORT_HOME}/headers") + set(ORT_LIB_DIR "${ORT_HOME}/jni/${ANDROID_ABI}") +else() + set(ORT_INCLUDE_DIR "${ORT_HOME}/include") + set(ORT_LIB_DIR "${ORT_HOME}/lib") +endif() + +file(MAKE_DIRECTORY "${ORT_INCLUDE_DIR}") +file(MAKE_DIRECTORY "${ORT_LIB_DIR}") + +set(ORT_HEADERS_DIR "${ONNXRUNTIME_SRC_DIR}/include/onnxruntime/core/session/") + +file(COPY ${ORT_HEADERS_DIR} DESTINATION ${ORT_INCLUDE_DIR}) +file(TOUCH ${ORT_LIB_DIR}/libonnxruntime.so "") +file(TOUCH ${ORT_LIB_DIR}/libonnxruntime.dylib "") + +# -------------------------------------------------------------------- +# onnxruntime-genai build +# -------------------------------------------------------------------- + +set(ONNXRT_GENAI_SRC_DIR "${CMAKE_BINARY_DIR}/onnxruntime-genai" + CACHE PATH "Path where onnxruntime-genai repo should be cloned") + +set(ONNXRT_GENAI_GIT_URL "https://github.com/microsoft/onnxruntime-genai.git" + CACHE STRING "Git URL for onnxruntime-genai repo") + +# Latest stable tag (June 2025) +set(ONNXRT_GENAI_GIT_TAG "v0.8.2" + CACHE STRING "Git tag / commit SHA for onnxruntime-genai repo") + +# -------------------------------------------------------------------- +# onnxruntime-genai build flags +# Disable GPU execution providers unless the caller turns them back on +# (e.g. cmake .. -DUSE_CUDA=ON). +# -------------------------------------------------------------------- + +set(USE_CUDA OFF) +set(USE_ROCM OFF) +set(USE_DML OFF) +set(ENABLE_PYTHON OFF) +set(ENABLE_JAVA OFF) +set(ENABLE_TESTS OFF) + +# model-benchmark standalone application +set(ENABLE_MODEL_BENCHMARK ${BUILD_EXECUTABLE} CACHE BOOL "Enable model benchmark binary") + +# Fetch the dependency Git repo here +FetchContent_Declare(onnxruntime-genai + GIT_REPOSITORY ${ONNXRT_GENAI_GIT_URL} + GIT_TAG ${ONNXRT_GENAI_GIT_TAG} + GIT_SHALLOW 1 + SOURCE_DIR ${ONNXRT_GENAI_SRC_DIR}) + +FetchContent_MakeAvailable(onnxruntime-genai) + +# Remove the created onnxruntime shared library +file(REMOVE_RECURSE ${ORT_LIB_DIR}) + +# Check internally defined dependency targets are visible +if (NOT TARGET arm-llm-config) + message(FATAL_ERROR "arm-llm-config target not defined") +elseif (NOT TARGET arm-llm-interface) + message(FATAL_ERROR "arm-llm-interface target not defined") +endif() + +add_library(arm-llm-framework STATIC + ${CMAKE_CURRENT_SOURCE_DIR}/OnnxrtImpl.cpp) + +target_include_directories(arm-llm-framework PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} + ${ONNXRT_GENAI_SRC_DIR}/src) + +# List all libraries that we need to depend on here: +target_link_libraries(arm-llm-framework PUBLIC + onnxruntime-genai + onnxruntime + arm-llm-config + arm-llm-interface) + +if(ENABLE_MODEL_BENCHMARK) + target_include_directories(model_benchmark PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${ONNXRT_GENAI_SRC_DIR}/src) + + target_link_libraries(model_benchmark PRIVATE onnxruntime-genai ${ONNXRUNTIME_LIB}) + target_link_directories(model_benchmark PRIVATE ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) + add_dependencies(model_benchmark onnxruntime) +endif() diff --git a/src/cpp/frameworks/onnxruntime_genai/LlmImpl.hpp b/src/cpp/frameworks/onnxruntime_genai/LlmImpl.hpp new file mode 100644 index 0000000..0588f13 --- /dev/null +++ b/src/cpp/frameworks/onnxruntime_genai/LlmImpl.hpp @@ -0,0 +1,204 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: 
Apache-2.0 +// +#ifndef LLM_IMPL_HPP +#define LLM_IMPL_HPP + +#include "Llm.hpp" +#include "LlmConfig.hpp" + +#include "ort_genai.h" + +/* Forward declaration */ +class LLM; + +/** + * @brief ONNX Implementation of our LLM API + */ +class LLM::LLMImpl { + +public: + LLMImpl(); + ~LLMImpl(); + + /** + * Method to initialize a ONNX model + * @param config Configuration class with model's parameter and user defined parameters + */ + void LlmInit(const LlmConfig& config); + + /** + * Method to free all allocations pertaining to ONNX model + */ + void FreeLlm(); + + /** + * Function to retrieve the ONNX encode timings. + * @return The encoded tokens per second + */ + float GetEncodeTimings(); + + /** + * Function to retrieve the ONNX decode timings. + * @return The decoded tokens per second + */ + float GetDecodeTimings(); + + /** + * Function to reset the ONNX timing + */ + void ResetTimings(); + + /** + * Function to print the system info + * @return System info as a char pointer + */ + std::string SystemInfo(); + + /** + * Method to reset conversation history and preserve model's character prefix. + * If model's prefix is not defined all conversation history would be cleared + */ + void ResetContext(); + + /** + * Method to prompt encoding + * @param prompt Query to LLM + */ + void Encode(std::string& prompt); + + /** + * Method to produce next token + * @return the next token for encoded prompt + */ + std::string NextToken(); + + /** + * The method return the percentage of chat context filled + * @return chat capacity filled in cache as percentage number + */ + size_t GetChatProgress() const; + + /** + * Benchmarks the performance of the LLM model. + * + * This function evaluates the model's performance by processing a specified number of prompts + * and generating text sequences. It measures the speed of prompt evaluation and text + * generation, calculates average speeds and standard deviations over multiple repetitions, and + * compiles the results into a formatted string. + * + * @param prompts Number of prompts to process during benchmarking. + * @param eval_prompts Number of evaluation prompts for text generation. + * @param n_max_sq Maximum sequence length for text generation. + * @param n_rep Number of repetitions for benchmarking to obtain average metrics. + * @return A formatted string containing the benchmark results, including model description, + * size, number of parameters, backend information, and performance metrics for prompt + * evaluation and text generation. + */ + std::string BenchModel(int& prompts, int& eval_prompts, int& n_max_sq, int& n_rep); + + /** + * Method to get framework type + * @return string framework type + */ + std::string GetFrameworkType(); + +private: + // Framework type + std::string m_frameworkType{"onnxruntime-genai"}; + // Pointer to the loaded OgaModel used for inference + std::unique_ptr m_llmModelPtr {nullptr}; + // Pointer to the OgaConfig instance containing model configuration settings. + std::unique_ptr m_llmConfigsPtr {nullptr}; + // Pointer to the OgaGeneratorParams instance holding generation parameters. + std::unique_ptr m_llmGntParamsPtr {nullptr}; + // Pointer to the OgaGenerator object responsible for text generation. + std::unique_ptr m_llmGeneratorPtr {nullptr}; + // Pointer to the OgaTokenizer used for tokenizing input text. + std::unique_ptr m_tokenizerPtr {nullptr}; + // Pointer to the OgaTokenizerStream used for streaming tokenized outputs. 
+ std::unique_ptr m_tokenizerStreamPtr {nullptr}; + // Pointer to the OgaSequences container storing generated token sequences. + std::unique_ptr m_sequencesPtr {nullptr}; + + // Number of threads to use for model inference. + size_t m_numOfThreads{0}; + // Maximum context length (number of tokens) supported by the model. + int m_nCtx{2048}; + // Batch size for token generation operations. + size_t m_batchSz{0}; + // Filesystem path to the ONNX model. + std::string m_modelPath{""}; + // Indicates whether the LLM has been initialized. + bool m_llmInitialized{false}; + // Number of tokens currently filled in the context window + size_t m_contextFilled{0}; + // Prefix text prepended to each generation request. + std::string m_llmPrefix{""}; + // Flag indicating if the context window has been reset. + bool m_ctxResetted = false; + // Total number of decoded tokens + size_t m_totalDecodedTokens = 0; + // Total number of encoded tokens + size_t m_totalEncodedTokens = 0; + // Total time for decoder + double m_totalDecoderTime = 0.0; + // Total time for encoder + double m_totalEncoderTime = 0.0; + + + + /** + * Function to initialize the LLM model sequence + */ + void InitSequence(); + + /** + * Frees the memory holding the LLM Model sequence + */ + void FreeSequence(); + + /** + * Function to initialize the LLM model configs + */ + void InitConfigs(); + + /** + * Frees the memory holding the configs + */ + void FreeConfigs(); + + /** + * Function to initialize a new generator + */ + void InitGenerator(); + + /** + * Frees the memory holding the generator + */ + void FreeGenerator(); + + /** + * Function to initialize a new tokenizer + */ + void InitTokenizer(); + + /** + * Frees the memory holding the tokenizer + */ + void FreeTokenizer(); + + /** + * Function to load the chosen ONNX model to memory + */ + void LoadModel(); + + /** + * Frees the memory holding the ONNX model + */ + void FreeModel(); +}; + +#endif /* LLM_IMPL_HPP */ diff --git a/src/cpp/frameworks/onnxruntime_genai/OnnxrtImpl.cpp b/src/cpp/frameworks/onnxruntime_genai/OnnxrtImpl.cpp new file mode 100644 index 0000000..e4bcc9d --- /dev/null +++ b/src/cpp/frameworks/onnxruntime_genai/OnnxrtImpl.cpp @@ -0,0 +1,333 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "LlmImpl.hpp" +#include + +#define LOG_INF(...) 
\ + do { \ + fprintf(stdout, __VA_ARGS__); \ + } while (0) + +using Clock = std::chrono::high_resolution_clock; +using TimePoint = std::chrono::time_point; +using Duration = std::chrono::duration; + +/** + * @brief ONNX Implementation of our LLM API + * + */ +LLM::LLMImpl::LLMImpl() {} + +LLM::LLMImpl::~LLMImpl() +{ + this->FreeLlm(); +} + +void LLM::LLMImpl::InitSequence() +{ + this->m_sequencesPtr = OgaSequences::Create(); + + if (this->m_sequencesPtr == nullptr) { + throw std::runtime_error("Error: unable to init sequence"); + } + + LOG_INF("Seuqence Initialized\n"); +} + +void LLM::LLMImpl::FreeSequence() +{ + if (this->m_sequencesPtr) { + this->m_sequencesPtr.reset(); + this->m_sequencesPtr = nullptr; + LOG_INF("Freed Sequences\n"); + } +} + +void LLM::LLMImpl::InitConfigs() +{ + // genai_config.json path (same as model path) + this->m_llmConfigsPtr = OgaConfig::Create(this->m_modelPath.c_str()); + + if (this->m_llmConfigsPtr == nullptr) { + throw std::runtime_error("Error: configs initialization failed"); + } + + // This will fall back to default provider which is: CPU + this->m_llmConfigsPtr->ClearProviders(); + + // Currently we modify only thread numbers, but different session options can be modified + // Ref: https://github.com/microsoft/onnxruntime-genai/blob/79d1d8470b74564fc4e723312a476e692057b600/src/config.h#L64 + std::string patch = + std::string(R"json({ + "model": { + "decoder": { + "session_options": { + "intra_op_num_threads": )json") + + std::to_string(this->m_numOfThreads) + + R"json( + } + } + } + })json"; + + this->m_llmConfigsPtr->Overlay(patch.c_str()); + LOG_INF("Configs Initialized\n"); +} + +void LLM::LLMImpl::FreeConfigs() +{ + if (this->m_llmConfigsPtr) { + this->m_llmConfigsPtr.reset(); + this->m_llmConfigsPtr = nullptr; + LOG_INF("Freed Configs\n"); + } +} + +void LLM::LLMImpl::InitGenerator() +{ + this->m_llmGntParamsPtr = OgaGeneratorParams::Create(* this->m_llmModelPtr); + + if (this->m_llmGntParamsPtr == nullptr) { + throw std::runtime_error("Error: generator params initialization failed"); + } + + this->m_llmGntParamsPtr->SetSearchOption("max_length", this->m_nCtx); + + this->m_llmGeneratorPtr = OgaGenerator::Create(* this->m_llmModelPtr, * this->m_llmGntParamsPtr); + + if (this->m_llmGeneratorPtr == nullptr) { + throw std::runtime_error("Error: generator initialization failed. 
Unable to create ONNX generator"); + } + + LOG_INF("Generator Initialized\n"); +} + +void LLM::LLMImpl::FreeGenerator() +{ + if (this->m_llmGeneratorPtr) { + this->m_llmGntParamsPtr.reset(); + this->m_llmGntParamsPtr = nullptr; + + this->m_llmGeneratorPtr.reset(); + this->m_llmGeneratorPtr = nullptr; + LOG_INF("Freed Generator\n"); + } +} + +void LLM::LLMImpl::InitTokenizer() +{ + this->m_tokenizerPtr = OgaTokenizer::Create(*this->m_llmModelPtr); + + if (this->m_tokenizerPtr == nullptr) { + throw std::runtime_error("Error: tokenizer initialization failed"); + } + + this->m_tokenizerStreamPtr = OgaTokenizerStream::Create(*this->m_tokenizerPtr); + + if (this->m_tokenizerStreamPtr == nullptr) { + throw std::runtime_error("Error: tokenizer stream initialization failed"); + } + + LOG_INF("Tokenizer Initialized\n"); + LOG_INF("Tokenizer Stream Initialized\n"); +} + +void LLM::LLMImpl::FreeTokenizer() +{ + if (this->m_tokenizerPtr) { + this->m_tokenizerPtr.reset(); + this->m_tokenizerPtr = nullptr; + + this->m_tokenizerStreamPtr.reset(); + this->m_tokenizerStreamPtr = nullptr; + LOG_INF("Freed Tokenizer\n"); + } +} + +void LLM::LLMImpl::LoadModel() +{ + this->m_llmModelPtr = OgaModel::Create(* this->m_llmConfigsPtr); + + if (this->m_llmModelPtr == nullptr) { + throw std::runtime_error("Error: unable to load model from " + std::string(this->m_modelPath)); + } + + LOG_INF("Model Loaded\n"); +} + +void LLM::LLMImpl::FreeModel() +{ + if (this->m_llmModelPtr) { + this->m_llmModelPtr.reset(); + this->m_llmModelPtr = nullptr; + LOG_INF("Freed Model\n"); + } + + this->m_llmInitialized = false; +} + +void LLM::LLMImpl::LlmInit(const LlmConfig& config) +{ + try { + this->m_batchSz = config.GetBatchSize(); + this->m_numOfThreads = config.GetNumThreads(); + this->m_modelPath = config.GetModelPath().c_str(); + this->m_llmPrefix = config.GetLlmPrefix(); + + InitConfigs(); + + if (this->m_llmConfigsPtr != nullptr) { + LoadModel(); + } + else { + LOG_INF("Config is not initialized\n"); + } + + if (this->m_llmModelPtr != nullptr) { + InitTokenizer(); + InitGenerator(); + } + + else { + LOG_INF("Model is not loaded\n"); + } + + if (this->m_llmConfigsPtr != nullptr && + this->m_tokenizerStreamPtr != nullptr && + this->m_llmGeneratorPtr != nullptr) { + + this->m_llmInitialized = true; + } + + else { + this->m_llmInitialized = false; + } + + } catch (const std::exception& e) { + throw std::runtime_error("LLM initialization failed: " + std::string(e.what())); + } + + LOG_INF("LLM Initialized\n"); +} + +void LLM::LLMImpl::FreeLlm() +{ + if (this->m_llmInitialized) { + FreeConfigs(); + FreeModel(); + FreeGenerator(); + FreeTokenizer(); + FreeSequence(); + ResetTimings(); + this->m_llmInitialized = false; + LOG_INF("Freed Entire LLM\n"); + } +} + +void LLM::LLMImpl::ResetContext() +{ + this->m_llmGeneratorPtr->RewindTo(0); + this->m_ctxResetted = true; + LOG_INF("Reset Context\n"); +} + +void LLM::LLMImpl::Encode(std::string& prompt) +{ + if (this->m_ctxResetted) { + prompt = this->m_llmPrefix + prompt; + this->m_ctxResetted = false; + } + + // Time start + TimePoint startTimeStampEncoder = Clock::now(); + + InitSequence(); + + this->m_tokenizerPtr->Encode(prompt.c_str(), * this->m_sequencesPtr); + this->m_llmGeneratorPtr->AppendTokenSequences(* this->m_sequencesPtr); + + // Record finishing time + this->m_totalEncoderTime += Duration(Clock::now() - startTimeStampEncoder).count(); + this->m_totalEncodedTokens += this->m_sequencesPtr->SequenceCount(0); + +} + +std::string LLM::LLMImpl::NextToken() +{ + 
if(!this->m_llmGeneratorPtr->IsDone()) { + // Record starting time + TimePoint startTimeStampDecoder = Clock::now(); + + this->m_llmGeneratorPtr->GenerateNextToken(); + size_t cnt = this->m_llmGeneratorPtr->GetSequenceCount(0); + int32_t tok = this->m_llmGeneratorPtr->GetSequenceData(0)[cnt - 1]; + auto out = this->m_tokenizerStreamPtr->Decode(tok); + + // Record finishing time + this->m_totalDecoderTime += Duration(Clock::now() - startTimeStampDecoder).count(); + this->m_totalDecodedTokens += 1; + + size_t nCurr = this->m_llmGeneratorPtr->GetSequenceCount(0); + + this->m_contextFilled = 100 * nCurr / this->m_nCtx; + + return out; + } + + else { + return "<|endoftext|>"; + } +} + +size_t LLM::LLMImpl::GetChatProgress() const +{ + return this->m_contextFilled; +} + +float LLM::LLMImpl::GetEncodeTimings() +{ + auto encoderTPS = this->m_totalEncodedTokens / this->m_totalEncoderTime; + return encoderTPS; +} + +float LLM::LLMImpl::GetDecodeTimings() +{ + auto decoderTPS = this->m_totalDecodedTokens / this->m_totalDecoderTime; + return decoderTPS; +} + +void LLM::LLMImpl::ResetTimings() +{ + this->m_totalDecoderTime = 0; + this->m_totalEncoderTime = 0; + this->m_totalDecodedTokens = 0; + this->m_totalEncodedTokens = 0; + LOG_INF("Reset Timings\n"); + +} + +std::string LLM::LLMImpl::SystemInfo() +{ + std::string sysInfo = "\nSystem INFO:\n"; + std::string deviceType = std::string(this->m_llmModelPtr->GetDeviceType()); + std::string modelType = std::string(this->m_llmModelPtr->GetType()); + sysInfo += "Device Type: " + deviceType + "\n"; + sysInfo += "Model Type: " + modelType + "\n"; + return sysInfo; +} + +std::string LLM::LLMImpl::BenchModel(int& prompts, int& eval_prompts, int& n_max_sq, int& n_rep) +{ + // TODO: Refactor BenchModel() into a framework-agnostic utility: + // Abstract the core benchmarking logic into a shared BenchModel(const Config&) function, + // Migrate each framework submodule to invoke it, and consolidate all parameters into the Config struct. + return (char *) nullptr; +} + +std::string LLM::LLMImpl::GetFrameworkType() +{ + return this->m_frameworkType; +} diff --git a/src/cpp/interface/Llm.hpp b/src/cpp/interface/Llm.hpp index 5355551..f7a5a22 100644 --- a/src/cpp/interface/Llm.hpp +++ b/src/cpp/interface/Llm.hpp @@ -91,6 +91,12 @@ public: */ std::string BenchModel(int& nPrompts, int& nEvalPrompts, int& nMaxSeq, int& nRep); + + /** + * Method to get framework type + * @return string framework type + */ + std::string GetFrameworkType(); }; #endif /* ARM_LLM_HPP */ diff --git a/src/java/com/arm/Llm.java b/src/java/com/arm/Llm.java index 5c87432..a994487 100644 --- a/src/java/com/arm/Llm.java +++ b/src/java/com/arm/Llm.java @@ -31,6 +31,7 @@ public class Llm extends SubmissionPublisher private long llmPtr = 0; private String modelTag = ""; private String userTag = ""; + private String endTag = ""; private List stopWords = null; private String cachedToken = ""; private String emitToken = ""; @@ -43,14 +44,17 @@ public class Llm extends SubmissionPublisher /** * Method to create LlmConfig cpp instance from params. 
* @param modelTag name used to refer the model + * @param userTag tag used to refer the user + * @param endTag tag to specify the end of the query * @param modelPath path to load model from * @param llmPrefix Initial-prompt to load into llm before query * @param numThreads Number of threads for inference * @param batchSize batch size used to chunk queries * @return pointer to llm config */ - public native long createLlmConfig(String modelTag, String modelPath, String llmPrefix, - int numThreads, int batchSize); + public native long createLlmConfig(String modelTag, String userTag, String endTag, + String modelPath, String llmPrefix, int numThreads, + int batchSize); /** * Method for loading LLM model * @param LlmConfig load model from LlmConfig @@ -107,7 +111,6 @@ public class Llm extends SubmissionPublisher /** * Method to decode answers one by one, once prefill stage is completed - * * @param nPrompts prompt length used for benchmarking * @param nEvalPrompts number of generated tokens for benchmarking * @param nMaxSeq sequence number @@ -121,6 +124,12 @@ public class Llm extends SubmissionPublisher int nRep ); + /** + * Method to get framework type + * @return string framework type + */ + public native String getFrameworkType(); + /** * Method to separate Initialization from constructor * @param llmConfig type configuration file to load model @@ -130,10 +139,12 @@ public class Llm extends SubmissionPublisher this.stopWords = llmConfig.getStopWords(); this.modelTag = llmConfig.getModelTag(); this.userTag = llmConfig.getUserTag(); + this.endTag = llmConfig.getEndTag(); this.llmPrefix = llmConfig.getLlmPrefix(); this.numThreads = llmConfig.getNumThreads(); - long configPtr = createLlmConfig(this.modelTag,llmConfig.getModelPath(), - this.llmPrefix,this.numThreads,this.batchSize); + long configPtr = createLlmConfig(this.modelTag, this.userTag, this.endTag, + llmConfig.getModelPath(), this.llmPrefix, + this.numThreads, this.batchSize); this.llmPtr = loadModel(configPtr); } @@ -143,7 +154,6 @@ public class Llm extends SubmissionPublisher */ public void setSubscriber(Flow.Subscriber subscriber) { - System.out.println("subscribed set from llama"); this.subscribe(subscriber); } @@ -156,9 +166,9 @@ public class Llm extends SubmissionPublisher String query = ""; AtomicBoolean stop = new AtomicBoolean(false); if (evaluatedOnce.get()) - query = userTag + Query + modelTag; + query = userTag + Query + endTag + modelTag; else - query = llmPrefix + Query + modelTag; + query = llmPrefix + userTag + Query + endTag + modelTag; encode(query); evaluatedOnce.set(true); while (getChatProgress()<100) @@ -189,9 +199,9 @@ public class Llm extends SubmissionPublisher String query = ""; boolean stop = false; if (evaluatedOnce.get()) - query = userTag + Query + modelTag; + query = userTag + Query + endTag + modelTag; else - query = llmPrefix + Query + modelTag; + query = llmPrefix + userTag + Query + endTag + modelTag; encode(query); evaluatedOnce.set(true); while (getChatProgress()<100) diff --git a/src/java/com/arm/LlmConfig.java b/src/java/com/arm/LlmConfig.java index 169b350..1897bbd 100644 --- a/src/java/com/arm/LlmConfig.java +++ b/src/java/com/arm/LlmConfig.java @@ -15,13 +15,14 @@ public class LlmConfig { private String modelTag; private String userTag; + private String endTag; private String modelPath; private String llmPrefix; private List stopWords; private int numThreads; /** - * Minimal constructor without userTag and numThreads + * Minimal constructor without userTag, endTag and numThreads * * @param 
modelTag tag for the model * @param stopWords stop words to use @@ -30,7 +31,7 @@ public class LlmConfig */ public LlmConfig(String modelTag, List stopWords, String modelPath, String llmPrefix) { - this(modelTag, stopWords, modelPath, llmPrefix, "", 4); + this(modelTag, stopWords, modelPath, llmPrefix, "", "", 4); } /** @@ -41,15 +42,16 @@ public class LlmConfig * @param modelPath path to the model * @param llmPrefix llm prefix to use * @param userTag user tag to use + * @param endTag end tag to use */ - public LlmConfig(String modelTag, List stopWords, String modelPath, String llmPrefix, String userTag) + public LlmConfig(String modelTag, List stopWords, String modelPath, String llmPrefix, String userTag, String endTag) { // Use 4 threads by default - this(modelTag, stopWords, modelPath, llmPrefix, userTag, 4); + this(modelTag, stopWords, modelPath, llmPrefix, userTag, endTag, 4); } /** - * Minimal constructor without userTag + * Minimal constructor without userTag, and endTag * * @param modelTag tag for the model * @param stopWords stop words to use @@ -57,9 +59,9 @@ public class LlmConfig * @param llmPrefix llm prefix to use * @param numThreads number of threads to use */ - public LlmConfig(String modelTag, List stopWords, String modelPath, String llmPrefix,int numThreads) + public LlmConfig(String modelTag, List stopWords, String modelPath, String llmPrefix, int numThreads) { - this(modelTag, stopWords, modelPath, llmPrefix, "", numThreads); + this(modelTag, stopWords, modelPath, llmPrefix, "", "", numThreads); } /** @@ -70,16 +72,18 @@ public class LlmConfig * @param modelPath path to the model * @param llmPrefix llm prefix to use * @param userTag user tag to use + * @param endTag end tag to use * @param numThreads number of threads to use */ public LlmConfig(String modelTag, List stopWords, String modelPath, - String llmPrefix, String userTag, int numThreads) + String llmPrefix, String userTag, String endTag, int numThreads) { this.modelTag = modelTag; this.stopWords = stopWords; this.modelPath = modelPath; this.llmPrefix = llmPrefix; this.userTag = userTag; + this.endTag = endTag; this.numThreads = numThreads; } @@ -101,6 +105,15 @@ public class LlmConfig { return this.userTag; } + /** + * Gets the end tag. + * + * @return The end tag. + */ + public String getEndTag() + { + return this.endTag; + } /** * Gets the list of stop words. @@ -161,6 +174,16 @@ public class LlmConfig this.userTag = userTag; } + /** + * Sets the end tag. + * + * @param endTag The end tag to set. + */ + public void setEndTag(String endTag) + { + this.endTag = endTag; + } + /** * Sets the list of stop words. * @@ -192,10 +215,10 @@ public class LlmConfig } /** - * Sets the number of Threads. - * @param numThreads count of threads to use for LLM. - */ - + * Sets the number of Threads. + * + * @param numThreads count of threads to use for LLM. 
+ */ public void setNumThreads(int numThreads) { this.numThreads = numThreads; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 54886b2..31f2099 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -25,8 +25,10 @@ add_executable(llm-cpp-tests ${CMAKE_CURRENT_SOURCE_DIR}/cpp/LlmUtils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/LlmTest.cpp) -if (${LLM_DEP_NAME} STREQUAL "llama.cpp") +if (${LLM_FRAMEWORK} STREQUAL "llama.cpp") set(CONFIG_FILE_NAME "llamaConfig.txt" CACHE STRING "Path to the Llama config file") +elseif (${LLM_FRAMEWORK} STREQUAL "onnxruntime-genai") + set(CONFIG_FILE_NAME "onnxrtConfig.txt" CACHE STRING "Path to the ONNX Runtime GenAI config file") endif () diff --git a/test/cpp/LlmTest.cpp b/test/cpp/LlmTest.cpp index d0f520b..baaa897 100644 --- a/test/cpp/LlmTest.cpp +++ b/test/cpp/LlmTest.cpp @@ -49,7 +49,8 @@ TEST_CASE("Test Llm-Wrapper class") SetupTestConfig(stopWordsStream, &configTest, STOP_WORDS); std::string response; - std::string question = "What is the capital of France?" + configTest.GetModelTag(); + std::string question = configTest.GetUserTag() +"What is the capital of France?" + + configTest.GetEndTag() + configTest.GetModelTag(); std::string prefixedQuestion = configTest.GetLlmPrefix() + question; LLM llm; diff --git a/test/cpp/LlmUtils.cpp b/test/cpp/LlmUtils.cpp index 3d90bae..96acd00 100644 --- a/test/cpp/LlmUtils.cpp +++ b/test/cpp/LlmUtils.cpp @@ -79,6 +79,10 @@ LlmConfig GetConfig(std::unordered_map config, throw std::runtime_error("Missing required parameter: modelPath"); if (config.find("modelTag") == config.end()) throw std::runtime_error("Missing required parameter: modelTag"); + if (config.find("userTag") == config.end()) + throw std::runtime_error("Missing required parameter: userTag"); + if (config.find("endTag") == config.end()) + throw std::runtime_error("Missing required parameter: endTag"); if (config.find("llmPrefix") == config.end()) throw std::runtime_error("Missing required parameter: llmPrefix"); @@ -90,6 +94,8 @@ LlmConfig GetConfig(std::unordered_map config, throw std::runtime_error("Missing required parameter: stopWords"); return LlmConfig(config.at("modelTag"), + config.at("userTag"), + config.at("endTag"), config.at("modelPath"), config.at("llmPrefix"), userConfig.at("numThreads"), diff --git a/test/java/com/arm/LlmTestJNI.java b/test/java/com/arm/LlmTestJNI.java index f1e5f4a..4886579 100644 --- a/test/java/com/arm/LlmTestJNI.java +++ b/test/java/com/arm/LlmTestJNI.java @@ -27,6 +27,7 @@ public class LlmTestJNI { private static int numThreads = 4; private static String modelTag = ""; private static String userTag = ""; + private static String endTag = ""; private static String modelPath = ""; private static String llmPrefix = ""; private static List stopWords = new ArrayList(); @@ -80,6 +81,7 @@ public class LlmTestJNI { loadVariables(configFilePath); modelTag = variables.get("modelTag"); userTag = variables.getOrDefault("userTag",""); + endTag = variables.getOrDefault("endTag", ""); llmPrefix = variables.get("llmPrefix"); modelPath = modelDir + "/" + variables.get("llmModelName"); loadVariables(userConfigFilePath); @@ -95,7 +97,7 @@ public class LlmTestJNI { @Test public void testConfigLoading() { - LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag,numThreads); + LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix, userTag, endTag, numThreads); assertTrue("Model tag is not empty", !llmConfig.getModelTag().isEmpty()); assertTrue("LLM prefix is not empty", 
!llmConfig.getLlmPrefix().isEmpty()); assertTrue("Stop words list is not empty", !llmConfig.getStopWords().isEmpty()); @@ -103,11 +105,11 @@ public class LlmTestJNI { @Test public void testLlmPrefixSetting() { - LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag); + LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix, userTag, endTag, numThreads); Llm llm = new Llm(); llm.llmInit(llmConfig); - String newModelTag = ("Ferdia"); + String newModelTag = ("Ferdia:"); String newPrefix = "Transcript of a dialog, where the User interacts with an AI Assistant named " + newModelTag + ". " + newModelTag + " is helpful, polite, honest, good at writing and answers honestly with a maximum of two sentences. User:"; @@ -123,27 +125,26 @@ public class LlmTestJNI { @Test public void testInferenceWithContextReset() { - LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag,numThreads); - Llm llm = new Llm(); - llm.llmInit(llmConfig); - - String question1 = "What is the capital of the country, Morocco?"; - String response1 = llm.send(question1); - checkLlmMatch(response1, "Rabat", true); + LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix, userTag, endTag, numThreads); + Llm llm = new Llm(); + llm.llmInit(llmConfig); - // Resetting context should cause model to forget what country is being referred to - llm.resetContext(); + String question1 = "What is the capital of the country, Morocco?"; + String response1 = llm.send(question1); + checkLlmMatch(response1, "Rabat", true); - String question2 = "What languages do they speak there?"; - String response2 = llm.send(question2); - checkLlmMatch(response2, "Arabic", false); + // Resetting context should cause model to forget what country is being referred to + llm.resetContext(); - llm.freeModel(); + String question2 = "What languages do they speak there?"; + String response2 = llm.send(question2); + checkLlmMatch(response2, "Arabic", false); + llm.freeModel(); } @Test public void testInferenceWithoutContextReset() { - LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag,numThreads); + LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix, userTag, endTag, numThreads); Llm llm = new Llm(); llm.llmInit(llmConfig); @@ -154,13 +155,12 @@ public class LlmTestJNI { String question2 = "What languages do they speak there?"; String response2 = llm.send(question2); checkLlmMatch(response2, "Arabic", true); - llm.freeModel(); } @Test public void testInferenceHandlesEmptyQuestion() { - LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag,numThreads); + LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix, userTag, endTag, numThreads); Llm llm = new Llm(); llm.llmInit(llmConfig); @@ -177,14 +177,13 @@ public class LlmTestJNI { String question2 = "What languages do they speak there?"; String response2 = llm.send(question2); checkLlmMatch(response2, "Arabic", true); - llm.freeModel(); } @Test public void testMangoSubtractionLongConversation() { - LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag,numThreads); + LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix, userTag, endTag, numThreads); Llm llm = new Llm(); llm.llmInit(llmConfig); @@ -195,8 +194,8 @@ public class LlmTestJNI { // Set the initial ground truth in the conversation. 
String initialContext = "There are " + originalMangoes + " mangoes in a basket."; String initResponse = llm.send(initialContext); - String originalQuery = "How many mangoes did we start with? "; - String subtractQuery = "Subtract 1 mango. How many mangoes are left now? "; + String originalQuery = "How many mangoes did we start with?"; + String subtractQuery = "Remove 1 mango from the basket. How many mangoes left in the basket now?"; // **Assert that the model acknowledges the initial count of mangoes.** checkLlmMatch(initResponse, String.valueOf(originalMangoes), true); @@ -204,6 +203,11 @@ public class LlmTestJNI { // Loop to subtract 1 mango each iteration until reaching 0. for (int i = 1; i < originalMangoes; i++) { + // Modify the query during the conversation + if (i == 2) { + subtractQuery = "Good, remove 1 mango again from the basket. How many mangoes left in the basket now?"; + } + // Query to subtract one mango String subtractionResponse = llm.send(subtractQuery); mangoes -= 1; // Update our expected count @@ -219,6 +223,7 @@ public class LlmTestJNI { } String postResetResponse = llm.send(originalQuery); + checkLlmMatch(postResetResponse, String.valueOf(originalMangoes), false); llm.freeModel(); } @@ -232,7 +237,7 @@ public class LlmTestJNI { throw new RuntimeException("System properties for model_dir or config_file are not set!"); } - LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag,numThreads); + LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix, userTag, endTag, numThreads); Llm llm = new Llm(); llm.llmInit(llmConfig); @@ -257,8 +262,6 @@ public class LlmTestJNI { checkLlmMatch(response4, "Arabic", true); checkLlmMatch(response4, "French", true); - - // Free model after use llm.freeModel(); } } -- GitLab
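
Usage note: a minimal sketch of how a client drives the extended Java API after this change, with endTag threaded through LlmConfig into the prompt that Llm.send() assembles (userTag + query + endTag + modelTag). The tags, stop word, model path and thread count below are illustrative placeholders; take the real values from the model configuration file for your backend (llamaConfig.txt or onnxrtConfig.txt).

    import com.arm.Llm;
    import com.arm.LlmConfig;

    import java.util.Arrays;
    import java.util.List;

    public class LlmQuickStart {
        public static void main(String[] args) {
            // Placeholder values; substitute the tags, stop words and model path
            // from the configuration file that matches the selected backend.
            List<String> stopWords = Arrays.asList("<|endoftext|>");
            LlmConfig config = new LlmConfig(
                    "Assistant:",                       // modelTag
                    stopWords,
                    "/path/to/model",                   // modelPath
                    "You are a helpful assistant.\n",   // llmPrefix
                    "User:",                            // userTag
                    "\n",                               // endTag appended after each query
                    4);                                 // numThreads

            Llm llm = new Llm();
            llm.llmInit(config);
            System.out.println("Backend: " + llm.getFrameworkType());
            System.out.println(llm.send("What is the capital of France?"));
            llm.freeModel();
        }
    }

Because endTag defaults to an empty string in the shorter LlmConfig constructors, existing callers keep their previous prompt layout.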