diff --git a/CMakeLists.txt b/CMakeLists.txt
index 000e87c23be336675e01688268d9eaa555c018b9..729649e51dac2369f8a0d4985b4095bbce33e9f0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -13,10 +13,6 @@ if (NOT CMAKE_TOOLCHAIN_FILE)
set(CMAKE_TOOLCHAIN_FILE "${CMAKE_CURRENT_SOURCE_DIR}/scripts/cmake/toolchains/native.cmake")
endif()
-set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
-set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
-set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
-
message(STATUS "Using CMAKE_TOOLCHAIN_FILE: ${CMAKE_TOOLCHAIN_FILE}")
# Declare project
@@ -28,7 +24,7 @@ project(arm-llm-wrapper
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/scripts/cmake")
# Set the runtime and lib directories
-set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
+set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib/archive)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
diff --git a/README.md b/README.md
index 9250a178c6f1cb5c1c4979d99df827255ee578e9..adb5745479c9f68b722f645dc78a72437c80950d 100644
--- a/README.md
+++ b/README.md
@@ -12,8 +12,12 @@
* [Prerequisites](#prerequisites)
* [Configuration options](#configuration-options)
* [Conditional options](#conditional-options)
+ * [llama cpp options](#llama-cpp-options)
+ * [onnxruntime genai options](#onnxruntime-genai-options)
* [Quick start](#quick-start)
* [Neural network](#neural-network)
+ * [llama cpp model](#llama-cpp-model)
+ * [onnxruntime genai model](#onnxruntime-genai-model)
* [To build for Android](#to-build-for-android)
* [To build for Linux](#to-build-for-linux)
* [Generic aarch64 target](#generic-aarch64-target)
@@ -21,6 +25,8 @@
* [Native host build](#native-host-build)
* [Building and running tests](#building-and-running-tests)
* [To build an executable](#to-build-an-executable)
+ * [llama cpp](#llama-cpp)
+ * [onnxruntime genai](#onnxruntime-genai)
* [Trademarks](#trademarks)
* [License](#license)
@@ -29,8 +35,9 @@ This repo is designed for building an
[Arm® KleidiAI™](https://www.arm.com/markets/artificial-intelligence/software/kleidi)
enabled LLM library using CMake build system. It intends to provide an abstraction for different Machine Learning
frameworks/backends that Arm® KleidiAI™ kernels have been integrated into.
-Currently, it supports [llama.cpp](https://github.com/ggml-org/llama.cpp) backend but we intend to add support for
-other backends soon.
+Currently, it supports the [llama.cpp](https://github.com/ggml-org/llama.cpp) and
+[onnxruntime-genai](https://github.com/microsoft/onnxruntime-genai) backends, but we intend to add
+support for other backends, such as [mediapipe](https://github.com/google-ai-edge/mediapipe), soon.
The backend library (selected at CMake configuration stage) is wrapped by this project's thin C++ layer that could be used
directly for testing and evaluations. However, JNI bindings are also provided for developers targeting Android™ based
@@ -53,11 +60,9 @@ applications.
The project is designed to download the required software sources based on user
provided configuration options. CMake presets are available to use and set the following variables:
-- `LLM_DEP_NAME`: Currently supports only `llama.cpp`. Support for `mediapipe` and `executorch` may be added later.
-- `BUILD_SHARED_LIBS`: Build shared instead of static dependency libraries, specifically - ggml and common, disabled by default.
+- `LLM_FRAMEWORK`: Currently supports `llama.cpp` (default framework) and `onnxruntime-genai`.
- `BUILD_JNI_LIB`: Build the JNI shared library that other projects can consume, enabled by default.
- `BUILD_UNIT_TESTS`: Build C++ unit tests and add them to CTest, JNI tests will also be built, enabled by default.
-- `LLAMA_BUILD_COMMON`: Build llama's dependency Common, enabled by default.
- `BUILD_EXECUTABLE`: Build standalone applications, disabled by default.
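+
+For example, a configuration along these lines (an illustrative invocation only; combine it with your usual toolchain, ABI and generator options) selects the `onnxruntime-genai` framework and enables the standalone executables:
+
+```shell
+cmake -B build \
+    -DLLM_FRAMEWORK=onnxruntime-genai \
+    -DBUILD_EXECUTABLE=ON
+cmake --build ./build
+```
+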
> **NOTE**: If you need specific version of Java set the path in `JAVA_HOME` environment variable.
@@ -75,29 +80,83 @@ provided configuration options. CMake presets are available to use and set the f
### Conditional options
-For `llama.cpp` as dependency, these configuration parameters can be set:
+Each framework has its own set of conditional options.
+
+#### llama cpp options
+
+When `llama.cpp` is the selected framework, these configuration parameters can be set (see the example after this list):
- `LLAMA_SRC_DIR`: Source directory path that will be populated by CMake
configuration.
- `LLAMA_GIT_URL`: Git URL to clone the sources from.
- `LLAMA_GIT_SHA`: Git SHA for checkout.
+- `LLAMA_BUILD_COMMON`: Build llama's dependency Common, enabled by default.
+- `BUILD_SHARED_LIBS`: Build shared instead of static dependency libraries (specifically ggml and common), disabled by default.
+- `LLAMA_CURL`: Enable HTTP transport via libcurl for remote models or features requiring network communication, disabled by default.
+
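+As a sketch of how these options are combined (the values shown are illustrative, not required), a `llama.cpp` build with shared libraries and libcurl enabled could be configured like this:
+
+```shell
+cmake -B build \
+    -DLLM_FRAMEWORK=llama.cpp \
+    -DBUILD_SHARED_LIBS=ON \
+    -DLLAMA_CURL=ON
+cmake --build ./build
+```
+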
+#### onnxruntime genai options
+
+When using `onnxruntime-genai`, the `onnxruntime` dependency will be built from source. To customize
+the versions of both `onnxruntime` and `onnxruntime-genai`, the following configuration parameters
+can be used (see the example after these lists):
+
+onnxruntime:
+- `ONNXRUNTIME_SRC_DIR`: Source directory path that will be populated by CMake
+ configuration.
+- `ONNXRUNTIME_GIT_URL`: Git URL to clone the sources from.
+- `ONNXRUNTIME_GIT_TAG`: Git tag or commit SHA for checkout.
+
+onnxruntime-genai:
+- `ONNXRT_GENAI_SRC_DIR`: Source directory path that will be populated by CMake
+ configuration.
+- `ONNXRT_GENAI_GIT_URL`: Git URL to clone the sources from.
+- `ONNXRT_GENAI_GIT_TAG`: Git tag or commit SHA for checkout.
+
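+For example, to pin specific versions of both dependencies (the tags below simply repeat the tested defaults; any valid upstream tag or SHA can be used):
+
+```shell
+cmake -B build \
+    -DLLM_FRAMEWORK=onnxruntime-genai \
+    -DONNXRUNTIME_GIT_TAG=v1.22.0 \
+    -DONNXRT_GENAI_GIT_TAG=v0.8.2
+cmake --build ./build
+```
+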
+> **NOTE**: This repository has been tested with `onnxruntime` version `v1.22.0` and
+> `onnxruntime-genai` version `v0.8.2`.
## Quick start
By default, the JNI builds are enabled, and Arm® KleidiAI™ kernels are enabled on arm64/aarch64.
-To disable these, configure with: `-DGGML_CPU_KLEIDIAI=OFF`.
+To disable these, configure with: `-DUSE_KLEIDIAI=OFF`.
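+
+For example (an illustrative native configuration; add your usual toolchain options):
+
+```shell
+cmake -B build -DUSE_KLEIDIAI=OFF
+cmake --build ./build
+```
+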
### Neural network
-This project uses the **phi-2 model** as its default network. The model is distributed using the
-**Q4_0 quantization format**, which is highly recommended as it delivers effective inference times by striking a
-balance between computational efficiency and model performance.
+Each framework has a different default model.
+
+#### llama cpp model
+
+This project uses the **phi-2 model** as its default network for the `llama.cpp` framework.
+The model is distributed using the **Q4_0 quantization format**, which is highly recommended as it
+delivers effective inference times by striking a balance between computational efficiency and model performance.
- You can access the model from [Hugging Face](https://huggingface.co/ggml-org/models/blob/main/phi-2/ggml-model-q4_0.gguf).
- The default model configuration is declared in the [`requirements.json`](scripts/py/requirements.json) file.
However, any model supported by the backend library could be used.
-> **NOTE**: Currently only Q4_0 models are accelerated by Arm® KleidiAI™ kernels in llama.cpp.
+> **NOTE**: Currently only Q4_0 models are accelerated by Arm® KleidiAI™ kernels in `llama.cpp`.
+
+#### onnxruntime genai model
+
+This project uses the **Phi-4-mini-instruct-onnx** model as its default network for the `onnxruntime-genai` framework.
+The model is distributed in an **int4 quantization format** with a **block size of 32**, which is highly recommended as it
+delivers effective inference times by striking a balance between computational efficiency and model performance.
+
+- You can access the model from [Hugging Face](https://huggingface.co/microsoft/Phi-4-mini-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4).
+- The default model configuration is declared in the [`requirements.json`](scripts/py/requirements.json) file.
+
+However, any model supported by the backend library could be used.
+
+To use an ONNX model with this framework, the following files are required:
+- `genai_config.json`: Configuration file
+- `model_name.onnx`: ONNX model
+- `model_name.onnx.data`: ONNX model data
+- `tokenizer.json`: Tokenizer file
+- `tokenizer_config.json`: Tokenizer config file
+
+These files are essential for loading and running ONNX models effectively.
+
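+With the default download location, the populated model directory should contain files similar to the following (names taken from the default `requirements.json` entries):
+
+```shell
+ls resources_downloaded/models/onnxruntime-genai
+# genai_config.json  model.onnx  model.onnx.data  tokenizer.json  tokenizer_config.json
+```
+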
+> **NOTE**: Currently only int4 models with a block size of 32 are accelerated by Arm® KleidiAI™ kernels in `onnxruntime-genai`.
### To build for Android
For Android™ build, ensure the `NDK_PATH` is set to installed Android™ NDK, specify Android™ ABI and platform needed:
@@ -114,7 +173,7 @@ cmake --build ./build
### To build for Linux
-Building for Linux targets, with llama.cpp backend, `GGML_CPU_ARM_ARCH` can be set to provide the architecture flags.
+Building for Linux targets, with `llama.cpp` backend, `GGML_CPU_ARM_ARCH` can be set to provide the architecture flags.
#### Generic aarch64 target
@@ -131,7 +190,7 @@ cmake --build ./build
#### Aarch64 target with SME
-To build for aarch64 Linux system with [Scalable Matrix Extensions](https://developer.arm.com/documentation/109246/0100/SME-Overview/SME-and-SME2), ensure `GGML_CPU_ARM_ARCH` is set with needed feature flags as below:
+To build for an aarch64 Linux system with [Scalable Matrix Extensions](https://developer.arm.com/documentation/109246/0100/SME-Overview/SME-and-SME2), for `llama.cpp` ensure `GGML_CPU_ARM_ARCH` is set with the needed feature flags as below:
```shell
cmake -B build \
@@ -145,16 +204,16 @@ cmake --build ./build
Once built, a standalone application can be executed to get performance.
If `FEAT_SME` is available on deployment target, environment variable `GGML_KLEIDIAI_SME` can be used to
-toggle the use of SME kernels during execution. For example:
+toggle the use of SME kernels during execution for `llama.cpp`. For example:
```shell
-GGML_KLEIDIAI_SME=1 ./build/bin/llama-cli -m resources_downloaded/models/model.gguf -t 1 -p "What is a car?"
+GGML_KLEIDIAI_SME=1 ./build/bin/llama-cli -m resources_downloaded/models/llama.cpp/model.gguf -t 1 -p "What is a car?"
```
To run without invoking SME kernels, set `GGML_KLEIDIAI_SME=0` during execution:
```shell
-GGML_KLEIDIAI_SME=0 ./build/bin/llama-cli -m resources_downloaded/models/model.gguf -t 1 -p "What is a car?"
+GGML_KLEIDIAI_SME=0 ./build/bin/llama-cli -m resources_downloaded/models/llama.cpp/model.gguf -t 1 -p "What is a car?"
```
> **NOTE**: In some cases, it may be desirable to build a statically linked executable. For llama.cpp backend
@@ -181,6 +240,8 @@ cmake --build ./build
ctest --test-dir ./build
```
+> **NOTE**: For consistent and reliable test results, avoid using the `--parallel` option when running tests.
+
This should produce something like:
```shell
Internal ctest changing into directory: /home/user/llm/build
@@ -221,12 +282,22 @@ cmake -B build \
cmake --build ./build
```
+### llama cpp
+
You can run either executable from the command line and add your prompt, for example:
```
-./build/bin/llama-cli -m resources_downloaded/models/model.gguf --prompt "What is the capital of France"
+./build/bin/llama-cli -m resources_downloaded/models/llama.cpp/model.gguf --prompt "What is the capital of France"
```
More information can be found at `llama.cpp/examples/main/README.md` on how this executable can be run.
+### onnxruntime genai
+
+You can run the `model_benchmark` executable from the command line:
+```
+./build/bin/model_benchmark -i resources_downloaded/models/onnxruntime-genai
+```
+More information can be found at `onnxruntime-genai/benchmark/c/readme.md` on how this executable can be run.
+
## Trademarks
* Arm® and KleidiAI™ are registered trademarks or trademarks of Arm® Limited (or its subsidiaries) in the US and/or
diff --git a/model_configuration_files/llamaConfig.txt b/model_configuration_files/llamaConfig.txt
index 457f2b60170b5fda6a6235077a7d99946dc0ec9d..b3b8a174206bb2df4b0fee1faaf1cb166f093761 100755
--- a/model_configuration_files/llamaConfig.txt
+++ b/model_configuration_files/llamaConfig.txt
@@ -1,4 +1,6 @@
modelTag=Orbita:
+userTag=
+endTag=
llmPrefix=Transcript of a dialog, where the User interacts with an AI Assistant named Orbita. Orbita is helpful, polite, honest, good at writing and answers honestly with a maximum of two sentences. User:
stopWords=Orbita:,User:,AI:,<|user|>,Assistant:,user:,[end of text],<|endoftext|>,model:,Question:,"\n\n",Consider the following scenario:\n
-llmModelName=model.gguf
\ No newline at end of file
+llmModelName=llama.cpp/model.gguf
diff --git a/model_configuration_files/onnxrtConfig.txt b/model_configuration_files/onnxrtConfig.txt
new file mode 100755
index 0000000000000000000000000000000000000000..8b6ccd7f06c08d2c649ff0ea8893f1dfecbb5e8c
--- /dev/null
+++ b/model_configuration_files/onnxrtConfig.txt
@@ -0,0 +1,6 @@
+modelTag=<|assistant|>
+userTag=<|user|>
+endTag=<|end|>
+llmPrefix=<|system|>Transcript of a dialog, where the User interacts with an AI Assistant named Orbita. Orbita is helpful, polite, honest, good at writing and answers honestly with a maximum of two sentences<|end|>
+stopWords=Orbita:,User:,User,AI:,<|user|>,Assistant:,<|assistant|>,<|end|>,user:,[end of text],<|endoftext|>,model:,Question:,Consider the following scenario:\n
+llmModelName=onnxruntime-genai
diff --git a/scripts/cmake/check-flag.cmake b/scripts/cmake/check-flag.cmake
index f714a2e4d8dca95e0f332ac6ffc2602e17489662..d37f10cad356748bb8adc72bf503368aac97cec6 100644
--- a/scripts/cmake/check-flag.cmake
+++ b/scripts/cmake/check-flag.cmake
@@ -38,3 +38,23 @@ function(check_compiler_support LANG FLAG)
message(STATUS "The compiler supports the ${LANG} flag ${FLAG}!")
endif()
endfunction()
+
+function(set_kleidiai_flag)
+
+ # If the user has NOT explicitly set USE_KLEIDIAI
+ if (NOT DEFINED USE_KLEIDIAI)
+ # if we are on arm64/aarch64, then default KleidiAI to ON.
+ if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64|ARM64)$")
+ set(USE_KLEIDIAI ON CACHE BOOL
+ "Enable KleidiAI by default on ${CMAKE_SYSTEM_PROCESSOR}")
+ message(STATUS "KleidiAI enabled by default")
+ # if we are NOT on arm64/aarch64, then default KleidiAI to OFF.
+ else()
+ set(USE_KLEIDIAI OFF CACHE BOOL
+ "Disable KleidiAI by default on ${CMAKE_SYSTEM_PROCESSOR}")
+ message(STATUS "KleidiAI disabled by default")
+ endif()
+ else ()
+ message(STATUS "KleidiAI: ${USE_KLEIDIAI}")
+ endif()
+endfunction()
\ No newline at end of file
diff --git a/scripts/cmake/configuration-options.cmake b/scripts/cmake/configuration-options.cmake
index 8f8f60d3596207de53b4ee374cfce5bd76a12b50..17787c8646e6eab94fb0d5eb5a2d81a2ad9e192f 100644
--- a/scripts/cmake/configuration-options.cmake
+++ b/scripts/cmake/configuration-options.cmake
@@ -6,14 +6,13 @@
include_guard(GLOBAL)
-set(LLM_DEP_NAME "llama.cpp" CACHE STRING
+set(LLM_FRAMEWORK "llama.cpp" CACHE STRING
"Dependency name to configure the project for")
# Available options:
-set(CACHE LLM_DEP_NAME PROPERTY STRINGS
+set(CACHE LLM_FRAMEWORK PROPERTY STRINGS
"llama.cpp"
- "mediapipe"
- "executorch")
+ "onnxruntime-genai")
set(DOWNLOADS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/resources_downloaded
CACHE STRING
diff --git a/scripts/cmake/configuration-presets.json b/scripts/cmake/configuration-presets.json
index 9c6fd46ea12ce7b8ff59561be0198ee764a21734..2d1cd7025a39abc7f8e685d2e97ff207153eb283 100644
--- a/scripts/cmake/configuration-presets.json
+++ b/scripts/cmake/configuration-presets.json
@@ -50,39 +50,38 @@
}
}
},
-
- {
- "name": "llama.cpp",
- "description": "Build with llama.cpp as the backend.",
- "hidden": true,
- "cacheVariables": {
- "LLM_DEP_NAME": {
- "type": "STRING",
- "value": "llama.cpp"
- }
- }
- },
- {
- "name": "build-shared-enabled",
- "description": "Configure with shared libraries enabled.",
- "hidden": true,
- "cacheVariables": {
- "BUILD_SHARED_LIBS": {
- "type": "BOOL",
- "value": "ON"
- }
- }
- },
- {
- "name": "build-shared-disabled",
- "description": "Configure with shared libraries disabled.",
- "hidden": true,
- "cacheVariables": {
- "BUILD_SHARED_LIBS": {
- "type": "BOOL",
- "value": "OFF"
- }
- }
- }
- ]
+ {
+ "name": "llama.cpp",
+ "description": "Build with llama.cpp as the backend.",
+ "hidden": true,
+ "cacheVariables": {
+ "LLM_FRAMEWORK": {
+ "type": "STRING",
+ "value": "llama.cpp"
+ }
+ }
+ },
+ {
+ "name": "build-shared-enabled",
+ "description": "Configure with shared libraries enabled.",
+ "hidden": true,
+ "cacheVariables": {
+ "BUILD_SHARED_LIBS": {
+ "type": "BOOL",
+ "value": "ON"
+ }
+ }
+ },
+ {
+ "name": "build-shared-disabled",
+ "description": "Configure with shared libraries disabled.",
+ "hidden": true,
+ "cacheVariables": {
+ "BUILD_SHARED_LIBS": {
+ "type": "BOOL",
+ "value": "OFF"
+ }
+ }
+ }
+ ]
}
diff --git a/scripts/cmake/download-resources.cmake b/scripts/cmake/download-resources.cmake
index 4b337e0d4fbbdf17263a08c71e22765bc9f64888..b523529a33dab5ad0e0bc9e87f7d758963caf0d4 100644
--- a/scripts/cmake/download-resources.cmake
+++ b/scripts/cmake/download-resources.cmake
@@ -45,6 +45,8 @@ execute_process(
${CMAKE_CURRENT_LIST_DIR}/../py/requirements.json
--download-dir
${DOWNLOADS_DIR}
+ --llm-framework
+ ${LLM_FRAMEWORK}
RESULT_VARIABLE return_code)
# Release the lock:
diff --git a/scripts/py/download_resources.py b/scripts/py/download_resources.py
index 22ccbc32ccb200257df7cda326de7922f4db6f6f..5794ab820c671d671b5adbf1f991727968ca3d8f 100644
--- a/scripts/py/download_resources.py
+++ b/scripts/py/download_resources.py
@@ -11,6 +11,7 @@ import urllib.request
import logging
import sys
from argparse import ArgumentParser
+
def download_file(url: str, dest: Path) -> None:
"""
Download a file
@@ -59,12 +60,20 @@ def download_resources(resources_file: Path, download_dir: Path) -> None:
for resource_type in resource_list:
resource_dir = Path(download_dir / resource_type)
resource_dir.mkdir(exist_ok=True)
- for resource_data in resource_list[resource_type]:
- logging.info(f'Name: {resource_data["name"]}')
- logging.info(f'Purpose: {resource_data["purpose"]}')
- logging.info(f'Dest: {resource_data["destination"]}')
- logging.info(f'URL: {resource_data["url"]}')
- logging.info(f'SHA256: {resource_data["sha256sum"]}')
+
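+ # Models in requirements.json are grouped per framework; only the entries for the
+ # selected --llm-framework are downloaded, into <download-dir>/models/<framework>/.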
+ if resource_type == "models":
+ resources = resource_list[resource_type][llm_framework]
+ model_dir = Path(resource_dir / llm_framework)
+ model_dir.mkdir(exist_ok=True)
+ else:
+ resources = resource_list[resource_type]
+
+ for resource_data in resources:
+ logging.info(f'Name: {resource_data["name"]}')
+ logging.info(f'Purpose: {resource_data["purpose"]}')
+ logging.info(f'Dest: {resource_data["destination"]}')
+ logging.info(f'URL: {resource_data["url"]}')
+ logging.info(f'SHA256: {resource_data["sha256sum"]}')
url = resource_data['url']
dest = resource_dir / resource_data['destination']
@@ -83,6 +92,7 @@ def download_resources(resources_file: Path, download_dir: Path) -> None:
current_file_dir = Path(__file__).parent.resolve()
default_requirements_file = current_file_dir / 'requirements.json'
default_download_location = current_file_dir / '..' / '..' / 'resources_downloaded'
+default_llm_framework = "llama.cpp"
if __name__ == "__main__":
@@ -98,10 +108,15 @@ if __name__ == "__main__":
"--download-dir",
help="Path to where resources should be downloaded.",
default=default_download_location)
+ parser.add_argument(
+ "--llm-framework",
+ help="LLM framework from which the model will be downloaded.",
+ default=default_llm_framework)
args = parser.parse_args()
req_file = Path(args.requirements_file)
download_dir = Path(args.download_dir)
+ llm_framework = args.llm_framework
if not req_file.exists():
raise FileNotFoundError(f'{req_file} does not exist')
diff --git a/scripts/py/requirements.json b/scripts/py/requirements.json
index 1340e7c7cba4f80a0388a8decc239af4686c9bc1..a3f81fbfc24e3e7f98d8e58077b5e7cddb7f95dd 100644
--- a/scripts/py/requirements.json
+++ b/scripts/py/requirements.json
@@ -1,27 +1,70 @@
{
- "models": [
- {
+ "models": {
+ "llama.cpp":
+ [
+ {
"name": "phi-2 model",
"purpose": "To enable basic testing",
- "destination": "model.gguf",
+ "destination": "llama.cpp/model.gguf",
"url": "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true",
"sha256sum": "fd506d24a4bee6997a566b02b65715af5cadb433c6a3a47a74b467badc5727ca"
- }
- ],
- "test_utils": [
+ }
+ ],
+ "onnxruntime-genai" :
+ [
+ {
+ "name": "genai config file",
+ "purpose": "Model configuration file",
+ "destination": "onnxruntime-genai/genai_config.json",
+ "url": "https://huggingface.co/microsoft/Phi-4-mini-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/genai_config.json",
+ "sha256sum": "0fcfa1e663f2bc867f8dc62fae65dd0924f0a4d68b43d1234df742dd19171470"
+ },
+ {
+ "name": "phi-4-mini-instruct-onnx (CPU/mobile)",
+ "purpose": "ONNX model for basic testing",
+ "destination": "onnxruntime-genai/model.onnx",
+ "url": "https://huggingface.co/microsoft/Phi-4-mini-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/model.onnx",
+ "sha256sum": "701aa5d185b6a782bc27104a990dd5b634fa507840b7c42f7ee6f1fb812d0b83"
+ },
+ {
+ "name": "phi-4-mini-instruct-onnx.data (CPU/mobile)",
+ "purpose": "ONNX model data for basic testing",
+ "destination": "onnxruntime-genai/model.onnx.data",
+ "url": "https://huggingface.co/microsoft/Phi-4-mini-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/model.onnx.data",
+ "sha256sum": "cb0267fa60befa1a4ade8c98b6d32a3d67f51abbd307c7f793f132e8d9092131"
+ },
+ {
+ "name": "tokenizer file",
+ "purpose": "Tokenizer for the model",
+ "destination": "onnxruntime-genai/tokenizer.json",
+ "url": "https://huggingface.co/microsoft/Phi-4-mini-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/tokenizer.json",
+ "sha256sum": "382cc235b56c725945e149cc25f191da667c836655efd0857b004320e90e91ea"
+ },
+ {
+ "name": "tokenizer config file",
+ "purpose": "Tokenizer config for the model",
+ "destination": "onnxruntime-genai/tokenizer_config.json",
+ "url": "https://huggingface.co/microsoft/Phi-4-mini-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/tokenizer_config.json",
+ "sha256sum": "c565326a315fbe62cda093a59d298828c8f3f823122661325f41f3ba577a7dec"
+ }
+ ]
+ },
+
+ "test_utils":
+ [
{
- "name": "hamcrest",
- "purpose": "JUnit tests for validating the Java API",
- "destination": "hamcrest-all-1.3.jar",
- "url": "https://repo1.maven.org/maven2/org/hamcrest/hamcrest-all/1.3/hamcrest-all-1.3.jar",
- "sha256sum": "4877670629ab96f34f5f90ab283125fcd9acb7e683e66319a68be6eb2cca60de"
+ "name": "hamcrest",
+ "purpose": "JUnit tests for validating the Java API",
+ "destination": "hamcrest-all-1.3.jar",
+ "url": "https://repo1.maven.org/maven2/org/hamcrest/hamcrest-all/1.3/hamcrest-all-1.3.jar",
+ "sha256sum": "4877670629ab96f34f5f90ab283125fcd9acb7e683e66319a68be6eb2cca60de"
},
{
- "name": "JUnit",
- "purpose": "JUnit tests for validating the Java API",
- "destination": "junit-4.13.2.jar",
- "url": "https://repo1.maven.org/maven2/junit/junit/4.13.2/junit-4.13.2.jar",
- "sha256sum": "8e495b634469d64fb8acfa3495a065cbacc8a0fff55ce1e31007be4c16dc57d3"
+ "name": "JUnit",
+ "purpose": "JUnit tests for validating the Java API",
+ "destination": "junit-4.13.2.jar",
+ "url": "https://repo1.maven.org/maven2/junit/junit/4.13.2/junit-4.13.2.jar",
+ "sha256sum": "8e495b634469d64fb8acfa3495a065cbacc8a0fff55ce1e31007be4c16dc57d3"
}
]
-}
+}
\ No newline at end of file
diff --git a/src/cpp/Llm.cpp b/src/cpp/Llm.cpp
index 5e337bb2541bd040c2514c540bd45e01e5028667..38ff2c39f8963aab538496de9619dd606161f9bc 100644
--- a/src/cpp/Llm.cpp
+++ b/src/cpp/Llm.cpp
@@ -69,3 +69,8 @@ std::string LLM::BenchModel(int& nPrompts, int& nEvalPrompts, int& nMaxSeq, int&
{
return this->m_impl->BenchModel(nPrompts, nEvalPrompts, nMaxSeq, nRep);
}
+
+std::string LLM::GetFrameworkType()
+{
+ return this->m_impl->GetFrameworkType();
+}
diff --git a/src/cpp/LlmJni.cpp b/src/cpp/LlmJni.cpp
index 4f595dd1533cf089a4b74f16c9ed05d38f9a59cd..662a6e186f7445334e31d2e9e7babe76f746c9e7 100644
--- a/src/cpp/LlmJni.cpp
+++ b/src/cpp/LlmJni.cpp
@@ -8,7 +8,7 @@
#include
-static std::unique_ptr<LLM> llm{nullptr};
+static std::unique_ptr<LLM> llm = std::make_unique<LLM>();
#ifdef __cplusplus
extern "C" {
@@ -17,16 +17,22 @@ extern "C" {
JNIEXPORT jlong JNICALL Java_com_arm_Llm_createLlmConfig(JNIEnv* env,
jobject /* this */,
jstring jModelTag,
+ jstring jUserTag,
+ jstring jEndTag,
jstring jModelPath,
jstring jLlmPrefix,
jint jNumThreads,
jint jBatchSize)
{
const char* modelTag = env->GetStringUTFChars(jModelTag, nullptr);
+ const char* userTag = env->GetStringUTFChars(jUserTag, nullptr);
+ const char* endTag = env->GetStringUTFChars(jEndTag, nullptr);
const char* modelPath = env->GetStringUTFChars(jModelPath, nullptr);
const char* llmPrefix = env->GetStringUTFChars(jLlmPrefix, nullptr);
auto* config = new LlmConfig(std::string(modelTag),
+ std::string(userTag),
+ std::string(endTag),
std::string(modelPath),
std::string(llmPrefix),
static_cast<int>(jNumThreads),
@@ -34,6 +40,8 @@ JNIEXPORT jlong JNICALL Java_com_arm_Llm_createLlmConfig(JNIEnv* env,
// Clean up
env->ReleaseStringUTFChars(jModelTag, modelTag);
+ env->ReleaseStringUTFChars(jUserTag, userTag);
+ env->ReleaseStringUTFChars(jEndTag, endTag);
env->ReleaseStringUTFChars(jModelPath, modelPath);
env->ReleaseStringUTFChars(jLlmPrefix, llmPrefix);
@@ -43,7 +51,6 @@ JNIEXPORT jlong JNICALL Java_com_arm_Llm_createLlmConfig(JNIEnv* env,
JNIEXPORT jlong JNICALL Java_com_arm_Llm_loadModel(JNIEnv* env, jobject, jlong pconfig)
{
auto config = reinterpret_cast<LlmConfig*>(pconfig);
- llm = std::make_unique();
llm->LlmInit(*config);
return 0;
}
@@ -74,7 +81,6 @@ JNIEXPORT jfloat JNICALL Java_com_arm_Llm_getEncodeRate(JNIEnv* env, jobject)
JNIEXPORT jfloat JNICALL Java_com_arm_Llm_getDecodeRate(JNIEnv* env, jobject)
{
-
float result = llm->GetDecodeTimings();
return result;
}
@@ -83,14 +89,17 @@ JNIEXPORT void JNICALL Java_com_arm_Llm_resetTimings(JNIEnv* env, jobject)
{
llm->ResetTimings();
}
+
JNIEXPORT jsize JNICALL Java_com_arm_Llm_getChatProgress(JNIEnv* env, jobject)
{
return llm->GetChatProgress();
}
+
JNIEXPORT void JNICALL Java_com_arm_Llm_resetContext(JNIEnv* env, jobject)
{
llm->ResetContext();
}
+
JNIEXPORT jstring JNICALL Java_com_arm_Llm_benchModel(
JNIEnv* env, jobject, jint nPrompts, jint nEvalPrompts, jint nMaxSeq, jint nRep)
{
@@ -98,6 +107,12 @@ JNIEXPORT jstring JNICALL Java_com_arm_Llm_benchModel(
return env->NewStringUTF(result.c_str());
}
+JNIEXPORT jstring JNICALL Java_com_arm_Llm_getFrameworkType(JNIEnv* env, jobject)
+{
+ std::string frameworkType = llm->GetFrameworkType();
+ return env->NewStringUTF(frameworkType.c_str());
+}
+
#ifdef __cplusplus
}
#endif
diff --git a/src/cpp/config/LlmConfig.cpp b/src/cpp/config/LlmConfig.cpp
index db5f95fd8bf65bf9039ccc81258fa4a4c8ad9902..f80c886f9667839b92b973f8e45740fc0106a9e2 100644
--- a/src/cpp/config/LlmConfig.cpp
+++ b/src/cpp/config/LlmConfig.cpp
@@ -7,16 +7,28 @@
#include
LlmConfig::LlmConfig(const std::string& modelTag,
+ const std::string& userTag,
+ const std::string& endTag,
const std::string& modelPath,
const std::string& llmPrefix,
int numThreads,
int batchSize) :
- m_modelTag(modelTag), m_modelPath(modelPath), m_llmPrefix(llmPrefix)
+ m_modelTag(modelTag), m_userTag(userTag), m_endTag(endTag), m_modelPath(modelPath), m_llmPrefix(llmPrefix)
{
SetNumThreads(numThreads);
SetBatchSize(batchSize);
}
+std::string LlmConfig::GetEndTag() const
+{
+ return this->m_endTag;
+}
+
+std::string LlmConfig::GetUserTag() const
+{
+ return this->m_userTag;
+}
+
std::string LlmConfig::GetModelTag() const
{
return this->m_modelTag;
@@ -47,6 +59,16 @@ void LlmConfig::SetModelTag(const std::string& modelIdentifier)
this->m_modelTag = modelIdentifier;
}
+void LlmConfig::SetUserTag(const std::string& userTag)
+{
+ this->m_userTag = userTag;
+}
+
+void LlmConfig::SetEndTag(const std::string& endTag)
+{
+ this->m_endTag = endTag;
+}
+
void LlmConfig::SetModelPath(const std::string& basePath)
{
this->m_modelPath = basePath;
diff --git a/src/cpp/config/LlmConfig.hpp b/src/cpp/config/LlmConfig.hpp
index 66713abd356bf93ddc3e1e71c7253066813bdc07..2216d660dbe1e14b859a183192a864b86193df57 100644
--- a/src/cpp/config/LlmConfig.hpp
+++ b/src/cpp/config/LlmConfig.hpp
@@ -15,6 +15,8 @@
class LlmConfig {
private:
std::string m_modelTag{};
+ std::string m_userTag{};
+ std::string m_endTag{};
std::string m_modelPath{};
std::string m_llmPrefix{};
int m_numThreads{};
@@ -24,12 +26,16 @@ public:
/**
* LlmConfig
* @param modelTag Model tag for the LLM model
+ * @param userTag User tag for the prompt
+ * @param endTag End tag for the prompt
* @param modelPath Path to the model
* @param llmPrefix LLM prefix to use
* @param numThreads Number of threads to use
* @param batchSize Batch size to use
*/
LlmConfig(const std::string& modelTag,
+ const std::string& userTag,
+ const std::string& endTag,
const std::string& modelPath,
const std::string& llmPrefix,
int numThreads,
@@ -37,6 +43,18 @@ public:
LlmConfig() = default;
+ /**
+ * Returns the end tag string.
+ * @return endTag
+ */
+ std::string GetEndTag() const;
+
+ /**
+ * Returns the user tag string.
+ * @return userTag
+ */
+ std::string GetUserTag() const;
+
/**
* Returns the model tag string (The name to appear in conversation with the LLM).
* @return modelTag
@@ -74,6 +92,20 @@ public:
*/
void SetModelTag(const std::string& modelIdentifier);
+ /**
+ * Sets the user tag
+ * @param userTag is the user tag added at the beginning of each user question to make the model
+ * respond appropriately
+ */
+ void SetUserTag(const std::string& userTag);
+
+ /**
+ * Sets the end tag
+ * @param endTag is the end tag added at the end of each user question to make the model
+ * respond appropriately
+ */
+ void SetEndTag(const std::string& endTag);
+
/**
* Sets the file path to the model.
* @param basePath absolute path to load llm model
diff --git a/src/cpp/frameworks/CMakeLists.txt b/src/cpp/frameworks/CMakeLists.txt
index 84166829cfa06fc156fc22ecff2e4e179d5464de..7a96bcfef9b0c712411a1a6d28b179068b261689 100644
--- a/src/cpp/frameworks/CMakeLists.txt
+++ b/src/cpp/frameworks/CMakeLists.txt
@@ -5,8 +5,10 @@
#
# Pull in LLM framework library
-if (${LLM_DEP_NAME} STREQUAL "llama.cpp")
+if (${LLM_FRAMEWORK} STREQUAL "llama.cpp")
add_subdirectory(llama_cpp)
+elseif(${LLM_FRAMEWORK} STREQUAL "onnxruntime-genai")
+ add_subdirectory(onnxruntime_genai)
else()
- message(FATAL_ERROR "${LLM_DEP_NAME} is currently not supported :(")
+ message(FATAL_ERROR "${LLM_FRAMEWORK} is currently not supported :(")
endif()
diff --git a/src/cpp/frameworks/llama_cpp/CMakeLists.txt b/src/cpp/frameworks/llama_cpp/CMakeLists.txt
index 08eb37d9eaafc3e2b76c45abea6d213ff96bf389..b6be17b198b7e41f1de96c78a59430d190f7b521 100644
--- a/src/cpp/frameworks/llama_cpp/CMakeLists.txt
+++ b/src/cpp/frameworks/llama_cpp/CMakeLists.txt
@@ -10,6 +10,7 @@ project(llama-cpp-wrapper
LANGUAGES C CXX ASM)
include(FetchContent)
+include(check-flag)
# Where should llama.cpp sources be cloned into?
# It might make sense to download sources into resources as well and not
@@ -29,22 +30,14 @@ set(LLAMA_GIT_SHA "a4090d1"
"Git commit SHA for llama.cpp repo")
set(LLAMA_BUILD_EXAMPLES ${BUILD_EXECUTABLE} CACHE BOOL "Build llama.cpp examples")
+set(LLAMA_CURL OFF CACHE BOOL "llama: use libcurl to download model from an URL")
-# If the user has NOT explicitly set GGML_CPU_KLEIDIAI
-if (NOT DEFINED GGML_CPU_KLEIDIAI)
- # if we are on arm64/aarch64, then default KleidiAI to ON.
- if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64|ARM64)$")
- set(GGML_CPU_KLEIDIAI ON CACHE BOOL
- "Enable KleidiAI by default on ${CMAKE_SYSTEM_PROCESSOR}")
- message(STATUS "KleidiAI enabled by default")
- # if we are NOT on arm64/aarch64, then default KleidiAI to OFF.
- else()
- set(GGML_CPU_KLEIDIAI OFF CACHE BOOL
- "Disable KleidiAI by default on ${CMAKE_SYSTEM_PROCESSOR}")
- message(STATUS "KleidiAI disabled by default")
- endif()
-else ()
- message(STATUS "KleidiAI: ${GGML_CPU_KLEIDIAI}")
+# KleidiAI configuration
+if(DEFINED GGML_CPU_KLEIDIAI)
+ message(FATAL_ERROR "Don't use framework specific KleidiAI flag, configure 'USE_KLEIDIAI' flag instead.")
+else()
+ set_kleidiai_flag()
+ set(GGML_CPU_KLEIDIAI ${USE_KLEIDIAI})
endif()
set(GGML_BUILD_NUMBER 1) # We do a shallow clone `--depth=1`
diff --git a/src/cpp/frameworks/llama_cpp/LlamaImpl.cpp b/src/cpp/frameworks/llama_cpp/LlamaImpl.cpp
index e5d9e21e8b4e4053b7532f4d87c1df4a9c071038..ffc00891a5c60e4715d56fa52744159bd75a5d3e 100644
--- a/src/cpp/frameworks/llama_cpp/LlamaImpl.cpp
+++ b/src/cpp/frameworks/llama_cpp/LlamaImpl.cpp
@@ -288,6 +288,11 @@ size_t LLM::LLMImpl::GetChatProgress() const
return this->m_contextFilled;
}
+std::string LLM::LLMImpl::GetFrameworkType()
+{
+ return this->m_frameworkType;
+}
+
std::string LLM::LLMImpl::BenchModel(int& prompts, int& eval_prompts, int& n_max_sq, int& n_rep)
{
auto prompts_avg = 0.0;
@@ -435,3 +440,4 @@ static bool is_valid_utf8(const char* string)
}
return true;
}
+
diff --git a/src/cpp/frameworks/llama_cpp/LlmImpl.hpp b/src/cpp/frameworks/llama_cpp/LlmImpl.hpp
index 1d3f027c4654e699e736f8840ec438101412c2aa..b969f18ab051d3eb5d0eb51d74ad740cbd2ccc21 100644
--- a/src/cpp/frameworks/llama_cpp/LlmImpl.hpp
+++ b/src/cpp/frameworks/llama_cpp/LlmImpl.hpp
@@ -28,13 +28,13 @@ public:
~LLMImpl();
/**
- * Method to Initialize a llama_model
+ * Method to initialize a llama_model
* @param config Configuration class with model's parameter and user defined parameters
*/
void LlmInit(const LlmConfig& config);
/**
- * Method to Free all allocations pertaining to llama model
+ * Method to free all allocations pertaining to llama model
*/
void FreeLlm();
@@ -62,7 +62,7 @@ public:
std::string SystemInfo();
/**
- * Method to reset Conversation history and preserve Model's character prefix.
+ * Method to reset conversation history and preserve model's character prefix.
* If model's prefix is not defined all conversation history would be cleared
*/
void ResetContext();
@@ -75,18 +75,18 @@ public:
/**
* Method to wrap CompletionLoop function
- * @return the next Token for Encoded Prompt
+ * @return the next token for encoded prompt
*/
std::string NextToken();
/**
- * @brief The Method return the percentage of chat context filled
+ * The method returns the percentage of chat context filled
* @return chat capacity filled in cache as percentage number
*/
size_t GetChatProgress() const;
/**
- * @brief Benchmarks the performance of the LLM model.
+ * Benchmarks the performance of the LLM model.
*
* This function evaluates the model's performance by processing a specified number of prompts
* and generating text sequences. It measures the speed of prompt evaluation and text
@@ -101,10 +101,16 @@ public:
* size, number of parameters, backend information, and performance metrics for prompt
* evaluation and text generation.
*/
-
std::string BenchModel(int& prompts, int& eval_prompts, int& n_max_sq, int& n_rep);
+ /**
+ * Method to get framework type
+ * @return string framework type
+ */
+ std::string GetFrameworkType();
+
private:
+ std::string m_frameworkType{"llama.cpp"};
llama_context* m_llmContext{nullptr};
llama_model* m_llmModel{nullptr};
llama_batch m_llmBatch{};
diff --git a/src/cpp/frameworks/onnxruntime_genai/CMakeLists.txt b/src/cpp/frameworks/onnxruntime_genai/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0528e41e30fbd05ecfb302cb648a8e2878be0da9
--- /dev/null
+++ b/src/cpp/frameworks/onnxruntime_genai/CMakeLists.txt
@@ -0,0 +1,177 @@
+#
+# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+
+# Declare project
+project(onnxruntime-genai-wrapper
+ DESCRIPTION "ONNX Runtime GenAI wrapper interface implementation"
+ LANGUAGES C CXX ASM)
+
+include(FetchContent)
+include(check-flag)
+
+# --------------------------------------------------------------------
+# onnxruntime build
+# --------------------------------------------------------------------
+
+# KleidiAI configuration
+if(DEFINED onnxruntime_USE_KLEIDIAI)
+ message(FATAL_ERROR "Don't use framework specific KleidiAI flag, configure 'USE_KLEIDIAI' flag instead.")
+else()
+ set_kleidiai_flag()
+ set(onnxruntime_USE_KLEIDIAI ${USE_KLEIDIAI})
+endif()
+
+# --------------------------------------------------------------------
+# onnxruntime build flags
+# onnxruntime_BUILD_SHARED_LIB needs to be ON to produce libonnxruntime.so
+# which is a dependency for libonnxruntime-genai.so
+# --------------------------------------------------------------------
+
+set(onnxruntime_BUILD_UNIT_TESTS OFF)
+set(onnxruntime_ENABLE_TESTS OFF)
+set(onnxruntime_DISABLE_RTTI ON)
+set(onnxruntime_DISABLE_FLOAT8_TYPES ON)
+set(onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS OFF)
+set(onnxruntime_USE_MIMALLOC OFF)
+set(onnxruntime_BUILD_SHARED_LIB ON)
+set(ONNX_BUILD_SHARED_LIBS OFF)
+
+set(ONNXRUNTIME_GIT_URL "https://github.com/microsoft/onnxruntime.git"
+ CACHE STRING "Git URL for onnxruntime repo")
+
+set(ONNXRUNTIME_GIT_TAG "v1.22.0"
+ CACHE STRING "Path where onnxruntime repo should be cloned")
+
+set(ONNXRUNTIME_SRC_DIR "${CMAKE_BINARY_DIR}/onnxruntime"
+ CACHE PATH "Source dir")
+
+set(ONNXRUNTIME_BUILD_DIR "${CMAKE_BINARY_DIR}/ort_build"
+ CACHE PATH "Onnxruntime build dir")
+
+file(MAKE_DIRECTORY "${ONNXRUNTIME_BUILD_DIR}")
+
+FetchContent_Declare(
+ onnxruntime
+ GIT_REPOSITORY ${ONNXRUNTIME_GIT_URL}
+ GIT_TAG ${ONNXRUNTIME_GIT_TAG}
+ SOURCE_DIR ${ONNXRUNTIME_SRC_DIR}
+ GIT_SHALLOW 1
+)
+
+FetchContent_MakeAvailable(onnxruntime)
+
+add_subdirectory(
+ ${ONNXRUNTIME_SRC_DIR}/cmake
+ ${ONNXRUNTIME_BUILD_DIR})
+
+# ---------------------------------------------------------------------------
+# Temporary workaround:
+#
+# The onnxruntime-genai build expects the onnxruntime header files and the shared
+# library to reside in a very specific folder structure at configuration
+# time.
+#
+# To build onnxruntime from source (avoiding the pre-built download), we create a
+# temporary staging directory that:
+# - Copies in the onnxruntime headers required by onnxruntime-genai, and
+# - Drops in a placeholder onnxruntime shared library just long enough for
+# onnxruntime-genai to detect it (the file is removed immediately afterward).
+#
+# The actual onnxruntime shared library produced by the build is written to
+# ${CMAKE_BINARY_DIR}/lib.
+# ---------------------------------------------------------------------------
+
+set(ORT_HOME "${ONNXRUNTIME_BUILD_DIR}/ort-tmp-src" CACHE STRING "Ort path")
+
+if(ANDROID)
+ set(ORT_INCLUDE_DIR "${ORT_HOME}/headers")
+ set(ORT_LIB_DIR "${ORT_HOME}/jni/${ANDROID_ABI}")
+else()
+ set(ORT_INCLUDE_DIR "${ORT_HOME}/include")
+ set(ORT_LIB_DIR "${ORT_HOME}/lib")
+endif()
+
+file(MAKE_DIRECTORY "${ORT_INCLUDE_DIR}")
+file(MAKE_DIRECTORY "${ORT_LIB_DIR}")
+
+set(ORT_HEADERS_DIR "${ONNXRUNTIME_SRC_DIR}/include/onnxruntime/core/session/")
+
+file(COPY ${ORT_HEADERS_DIR} DESTINATION ${ORT_INCLUDE_DIR})
+file(TOUCH ${ORT_LIB_DIR}/libonnxruntime.so "")
+file(TOUCH ${ORT_LIB_DIR}/libonnxruntime.dylib "")
+
+# --------------------------------------------------------------------
+# onnxruntime-genai build
+# --------------------------------------------------------------------
+
+set(ONNXRT_GENAI_SRC_DIR "${CMAKE_BINARY_DIR}/onnxruntime-genai"
+ CACHE PATH "Path where onnxruntime-genai repo should be cloned")
+
+set(ONNXRT_GENAI_GIT_URL "https://github.com/microsoft/onnxruntime-genai.git"
+ CACHE STRING "Git URL for onnxruntime-genai repo")
+
+# Latest stable tag (June 2025)
+set(ONNXRT_GENAI_GIT_TAG "v0.8.2"
+ CACHE STRING "Git tag / commit SHA for onnxruntime-genai repo")
+
+# --------------------------------------------------------------------
+# onnxruntime-genai build flags
+# Disable GPU execution providers unless the caller turns them back on
+# (e.g. cmake .. -DUSE_CUDA=ON).
+# --------------------------------------------------------------------
+
+set(USE_CUDA OFF)
+set(USE_ROCM OFF)
+set(USE_DML OFF)
+set(ENABLE_PYTHON OFF)
+set(ENABLE_JAVA OFF)
+set(ENABLE_TESTS OFF)
+
+# model-benchmark standalone application
+set(ENABLE_MODEL_BENCHMARK ${BUILD_EXECUTABLE} CACHE BOOL "Enable model benchmark binary")
+
+# Fetch the dependency Git repo here
+FetchContent_Declare(onnxruntime-genai
+ GIT_REPOSITORY ${ONNXRT_GENAI_GIT_URL}
+ GIT_TAG ${ONNXRT_GENAI_GIT_TAG}
+ GIT_SHALLOW 1
+ SOURCE_DIR ${ONNXRT_GENAI_SRC_DIR})
+
+FetchContent_MakeAvailable(onnxruntime-genai)
+
+# Remove the placeholder onnxruntime shared library directory created above
+file(REMOVE_RECURSE ${ORT_LIB_DIR})
+
+# Check internally defined dependency targets are visible
+if (NOT TARGET arm-llm-config)
+ message(FATAL_ERROR "arm-llm-config target not defined")
+elseif (NOT TARGET arm-llm-interface)
+ message(FATAL_ERROR "arm-llm-interface target not defined")
+endif()
+
+add_library(arm-llm-framework STATIC
+ ${CMAKE_CURRENT_SOURCE_DIR}/OnnxrtImpl.cpp)
+
+target_include_directories(arm-llm-framework PUBLIC
+ ${CMAKE_CURRENT_SOURCE_DIR}
+ ${ONNXRT_GENAI_SRC_DIR}/src)
+
+# List all libraries that we need to depend on here:
+target_link_libraries(arm-llm-framework PUBLIC
+ onnxruntime-genai
+ onnxruntime
+ arm-llm-config
+ arm-llm-interface)
+
+if(ENABLE_MODEL_BENCHMARK)
+ target_include_directories(model_benchmark PRIVATE
+ ${CMAKE_CURRENT_SOURCE_DIR}
+ ${ONNXRT_GENAI_SRC_DIR}/src)
+
+ target_link_libraries(model_benchmark PRIVATE onnxruntime-genai ${ONNXRUNTIME_LIB})
+ target_link_directories(model_benchmark PRIVATE ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
+ add_dependencies(model_benchmark onnxruntime)
+endif()
diff --git a/src/cpp/frameworks/onnxruntime_genai/LlmImpl.hpp b/src/cpp/frameworks/onnxruntime_genai/LlmImpl.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..0588f1345281582b317a4d04c0158bebb39cd6d8
--- /dev/null
+++ b/src/cpp/frameworks/onnxruntime_genai/LlmImpl.hpp
@@ -0,0 +1,204 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#ifndef LLM_IMPL_HPP
+#define LLM_IMPL_HPP
+
+#include "Llm.hpp"
+#include "LlmConfig.hpp"
+
+#include "ort_genai.h"
+
+/* Forward declaration */
+class LLM;
+
+/**
+ * @brief ONNX Implementation of our LLM API
+ */
+class LLM::LLMImpl {
+
+public:
+ LLMImpl();
+ ~LLMImpl();
+
+ /**
+ * Method to initialize an ONNX model
+ * @param config Configuration class with model's parameter and user defined parameters
+ */
+ void LlmInit(const LlmConfig& config);
+
+ /**
+ * Method to free all allocations pertaining to ONNX model
+ */
+ void FreeLlm();
+
+ /**
+ * Function to retrieve the ONNX encode timings.
+ * @return The encoded tokens per second
+ */
+ float GetEncodeTimings();
+
+ /**
+ * Function to retrieve the ONNX decode timings.
+ * @return The decoded tokens per second
+ */
+ float GetDecodeTimings();
+
+ /**
+ * Function to reset the ONNX timing
+ */
+ void ResetTimings();
+
+ /**
+ * Function to print the system info
+ * @return System info as a char pointer
+ */
+ std::string SystemInfo();
+
+ /**
+ * Method to reset conversation history and preserve model's character prefix.
+ * If model's prefix is not defined all conversation history would be cleared
+ */
+ void ResetContext();
+
+ /**
+ * Method to prompt encoding
+ * @param prompt Query to LLM
+ */
+ void Encode(std::string& prompt);
+
+ /**
+ * Method to produce next token
+ * @return the next token for encoded prompt
+ */
+ std::string NextToken();
+
+ /**
+ * The method returns the percentage of chat context filled
+ * @return chat capacity filled in cache as percentage number
+ */
+ size_t GetChatProgress() const;
+
+ /**
+ * Benchmarks the performance of the LLM model.
+ *
+ * This function evaluates the model's performance by processing a specified number of prompts
+ * and generating text sequences. It measures the speed of prompt evaluation and text
+ * generation, calculates average speeds and standard deviations over multiple repetitions, and
+ * compiles the results into a formatted string.
+ *
+ * @param prompts Number of prompts to process during benchmarking.
+ * @param eval_prompts Number of evaluation prompts for text generation.
+ * @param n_max_sq Maximum sequence length for text generation.
+ * @param n_rep Number of repetitions for benchmarking to obtain average metrics.
+ * @return A formatted string containing the benchmark results, including model description,
+ * size, number of parameters, backend information, and performance metrics for prompt
+ * evaluation and text generation.
+ */
+ std::string BenchModel(int& prompts, int& eval_prompts, int& n_max_sq, int& n_rep);
+
+ /**
+ * Method to get framework type
+ * @return string framework type
+ */
+ std::string GetFrameworkType();
+
+private:
+ // Framework type
+ std::string m_frameworkType{"onnxruntime-genai"};
+ // Pointer to the loaded OgaModel used for inference
+ std::unique_ptr<OgaModel> m_llmModelPtr {nullptr};
+ // Pointer to the OgaConfig instance containing model configuration settings.
+ std::unique_ptr<OgaConfig> m_llmConfigsPtr {nullptr};
+ // Pointer to the OgaGeneratorParams instance holding generation parameters.
+ std::unique_ptr<OgaGeneratorParams> m_llmGntParamsPtr {nullptr};
+ // Pointer to the OgaGenerator object responsible for text generation.
+ std::unique_ptr<OgaGenerator> m_llmGeneratorPtr {nullptr};
+ // Pointer to the OgaTokenizer used for tokenizing input text.
+ std::unique_ptr<OgaTokenizer> m_tokenizerPtr {nullptr};
+ // Pointer to the OgaTokenizerStream used for streaming tokenized outputs.
+ std::unique_ptr<OgaTokenizerStream> m_tokenizerStreamPtr {nullptr};
+ // Pointer to the OgaSequences container storing generated token sequences.
+ std::unique_ptr<OgaSequences> m_sequencesPtr {nullptr};
+
+ // Number of threads to use for model inference.
+ size_t m_numOfThreads{0};
+ // Maximum context length (number of tokens) supported by the model.
+ int m_nCtx{2048};
+ // Batch size for token generation operations.
+ size_t m_batchSz{0};
+ // Filesystem path to the ONNX model.
+ std::string m_modelPath{""};
+ // Indicates whether the LLM has been initialized.
+ bool m_llmInitialized{false};
+ // Number of tokens currently filled in the context window
+ size_t m_contextFilled{0};
+ // Prefix text prepended to each generation request.
+ std::string m_llmPrefix{""};
+ // Flag indicating if the context window has been reset.
+ bool m_ctxResetted = false;
+ // Total number of decoded tokens
+ size_t m_totalDecodedTokens = 0;
+ // Total number of encoded tokens
+ size_t m_totalEncodedTokens = 0;
+ // Total time for decoder
+ double m_totalDecoderTime = 0.0;
+ // Total time for encoder
+ double m_totalEncoderTime = 0.0;
+
+ /**
+ * Function to initialize the LLM model sequence
+ */
+ void InitSequence();
+
+ /**
+ * Frees the memory holding the LLM Model sequence
+ */
+ void FreeSequence();
+
+ /**
+ * Function to initialize the LLM model configs
+ */
+ void InitConfigs();
+
+ /**
+ * Frees the memory holding the configs
+ */
+ void FreeConfigs();
+
+ /**
+ * Function to initialize a new generator
+ */
+ void InitGenerator();
+
+ /**
+ * Frees the memory holding the generator
+ */
+ void FreeGenerator();
+
+ /**
+ * Function to initialize a new tokenizer
+ */
+ void InitTokenizer();
+
+ /**
+ * Frees the memory holding the tokenizer
+ */
+ void FreeTokenizer();
+
+ /**
+ * Function to load the chosen ONNX model to memory
+ */
+ void LoadModel();
+
+ /**
+ * Frees the memory holding the ONNX model
+ */
+ void FreeModel();
+};
+
+#endif /* LLM_IMPL_HPP */
diff --git a/src/cpp/frameworks/onnxruntime_genai/OnnxrtImpl.cpp b/src/cpp/frameworks/onnxruntime_genai/OnnxrtImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e4bcc9d67874ad8f4d4f042bba2fe6e5d2e202e3
--- /dev/null
+++ b/src/cpp/frameworks/onnxruntime_genai/OnnxrtImpl.cpp
@@ -0,0 +1,333 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#include "LlmImpl.hpp"
+#include <chrono> // for the Clock/TimePoint/Duration aliases below
+
+#define LOG_INF(...) \
+ do { \
+ fprintf(stdout, __VA_ARGS__); \
+ } while (0)
+
+using Clock = std::chrono::high_resolution_clock;
+using TimePoint = std::chrono::time_point<Clock>;
+using Duration = std::chrono::duration<double>;
+
+/**
+ * @brief ONNX Implementation of our LLM API
+ *
+ */
+LLM::LLMImpl::LLMImpl() {}
+
+LLM::LLMImpl::~LLMImpl()
+{
+ this->FreeLlm();
+}
+
+void LLM::LLMImpl::InitSequence()
+{
+ this->m_sequencesPtr = OgaSequences::Create();
+
+ if (this->m_sequencesPtr == nullptr) {
+ throw std::runtime_error("Error: unable to init sequence");
+ }
+
+ LOG_INF("Seuqence Initialized\n");
+}
+
+void LLM::LLMImpl::FreeSequence()
+{
+ if (this->m_sequencesPtr) {
+ this->m_sequencesPtr.reset();
+ this->m_sequencesPtr = nullptr;
+ LOG_INF("Freed Sequences\n");
+ }
+}
+
+void LLM::LLMImpl::InitConfigs()
+{
+ // genai_config.json path (same as model path)
+ this->m_llmConfigsPtr = OgaConfig::Create(this->m_modelPath.c_str());
+
+ if (this->m_llmConfigsPtr == nullptr) {
+ throw std::runtime_error("Error: configs initialization failed");
+ }
+
+ // This will fall back to default provider which is: CPU
+ this->m_llmConfigsPtr->ClearProviders();
+
+ // Currently we modify only thread numbers, but different session options can be modified
+ // Ref: https://github.com/microsoft/onnxruntime-genai/blob/79d1d8470b74564fc4e723312a476e692057b600/src/config.h#L64
+ std::string patch =
+ std::string(R"json({
+ "model": {
+ "decoder": {
+ "session_options": {
+ "intra_op_num_threads": )json")
+ + std::to_string(this->m_numOfThreads)
+ + R"json(
+ }
+ }
+ }
+ })json";
+
+ this->m_llmConfigsPtr->Overlay(patch.c_str());
+ LOG_INF("Configs Initialized\n");
+}
+
+void LLM::LLMImpl::FreeConfigs()
+{
+ if (this->m_llmConfigsPtr) {
+ this->m_llmConfigsPtr.reset();
+ this->m_llmConfigsPtr = nullptr;
+ LOG_INF("Freed Configs\n");
+ }
+}
+
+void LLM::LLMImpl::InitGenerator()
+{
+ this->m_llmGntParamsPtr = OgaGeneratorParams::Create(* this->m_llmModelPtr);
+
+ if (this->m_llmGntParamsPtr == nullptr) {
+ throw std::runtime_error("Error: generator params initialization failed");
+ }
+
+ this->m_llmGntParamsPtr->SetSearchOption("max_length", this->m_nCtx);
+
+ this->m_llmGeneratorPtr = OgaGenerator::Create(* this->m_llmModelPtr, * this->m_llmGntParamsPtr);
+
+ if (this->m_llmGeneratorPtr == nullptr) {
+ throw std::runtime_error("Error: generator initialization failed. Unable to create ONNX generator");
+ }
+
+ LOG_INF("Generator Initialized\n");
+}
+
+void LLM::LLMImpl::FreeGenerator()
+{
+ if (this->m_llmGeneratorPtr) {
+ this->m_llmGntParamsPtr.reset();
+ this->m_llmGntParamsPtr = nullptr;
+
+ this->m_llmGeneratorPtr.reset();
+ this->m_llmGeneratorPtr = nullptr;
+ LOG_INF("Freed Generator\n");
+ }
+}
+
+void LLM::LLMImpl::InitTokenizer()
+{
+ this->m_tokenizerPtr = OgaTokenizer::Create(*this->m_llmModelPtr);
+
+ if (this->m_tokenizerPtr == nullptr) {
+ throw std::runtime_error("Error: tokenizer initialization failed");
+ }
+
+ this->m_tokenizerStreamPtr = OgaTokenizerStream::Create(*this->m_tokenizerPtr);
+
+ if (this->m_tokenizerStreamPtr == nullptr) {
+ throw std::runtime_error("Error: tokenizer stream initialization failed");
+ }
+
+ LOG_INF("Tokenizer Initialized\n");
+ LOG_INF("Tokenizer Stream Initialized\n");
+}
+
+void LLM::LLMImpl::FreeTokenizer()
+{
+ if (this->m_tokenizerPtr) {
+ this->m_tokenizerPtr.reset();
+ this->m_tokenizerPtr = nullptr;
+
+ this->m_tokenizerStreamPtr.reset();
+ this->m_tokenizerStreamPtr = nullptr;
+ LOG_INF("Freed Tokenizer\n");
+ }
+}
+
+void LLM::LLMImpl::LoadModel()
+{
+ this->m_llmModelPtr = OgaModel::Create(* this->m_llmConfigsPtr);
+
+ if (this->m_llmModelPtr == nullptr) {
+ throw std::runtime_error("Error: unable to load model from " + std::string(this->m_modelPath));
+ }
+
+ LOG_INF("Model Loaded\n");
+}
+
+void LLM::LLMImpl::FreeModel()
+{
+ if (this->m_llmModelPtr) {
+ this->m_llmModelPtr.reset();
+ this->m_llmModelPtr = nullptr;
+ LOG_INF("Freed Model\n");
+ }
+
+ this->m_llmInitialized = false;
+}
+
+void LLM::LLMImpl::LlmInit(const LlmConfig& config)
+{
+ try {
+ this->m_batchSz = config.GetBatchSize();
+ this->m_numOfThreads = config.GetNumThreads();
+ this->m_modelPath = config.GetModelPath().c_str();
+ this->m_llmPrefix = config.GetLlmPrefix();
+
+ InitConfigs();
+
+ if (this->m_llmConfigsPtr != nullptr) {
+ LoadModel();
+ }
+ else {
+ LOG_INF("Config is not initialized\n");
+ }
+
+ if (this->m_llmModelPtr != nullptr) {
+ InitTokenizer();
+ InitGenerator();
+ }
+ else {
+ LOG_INF("Model is not loaded\n");
+ }
+
+ if (this->m_llmConfigsPtr != nullptr &&
+ this->m_tokenizerStreamPtr != nullptr &&
+ this->m_llmGeneratorPtr != nullptr) {
+
+ this->m_llmInitialized = true;
+ }
+ else {
+ this->m_llmInitialized = false;
+ }
+
+ } catch (const std::exception& e) {
+ throw std::runtime_error("LLM initialization failed: " + std::string(e.what()));
+ }
+
+ LOG_INF("LLM Initialized\n");
+}
+
+void LLM::LLMImpl::FreeLlm()
+{
+ if (this->m_llmInitialized) {
+ FreeConfigs();
+ FreeModel();
+ FreeGenerator();
+ FreeTokenizer();
+ FreeSequence();
+ ResetTimings();
+ this->m_llmInitialized = false;
+ LOG_INF("Freed Entire LLM\n");
+ }
+}
+
+void LLM::LLMImpl::ResetContext()
+{
+ this->m_llmGeneratorPtr->RewindTo(0);
+ this->m_ctxResetted = true;
+ LOG_INF("Reset Context\n");
+}
+
+void LLM::LLMImpl::Encode(std::string& prompt)
+{
+ if (this->m_ctxResetted) {
+ prompt = this->m_llmPrefix + prompt;
+ this->m_ctxResetted = false;
+ }
+
+ // Time start
+ TimePoint startTimeStampEncoder = Clock::now();
+
+ InitSequence();
+
+ this->m_tokenizerPtr->Encode(prompt.c_str(), * this->m_sequencesPtr);
+ this->m_llmGeneratorPtr->AppendTokenSequences(* this->m_sequencesPtr);
+
+ // Record finishing time
+ this->m_totalEncoderTime += Duration(Clock::now() - startTimeStampEncoder).count();
+ this->m_totalEncodedTokens += this->m_sequencesPtr->SequenceCount(0);
+
+}
+
+std::string LLM::LLMImpl::NextToken()
+{
+ if(!this->m_llmGeneratorPtr->IsDone()) {
+ // Record starting time
+ TimePoint startTimeStampDecoder = Clock::now();
+
+ this->m_llmGeneratorPtr->GenerateNextToken();
+ size_t cnt = this->m_llmGeneratorPtr->GetSequenceCount(0);
+ int32_t tok = this->m_llmGeneratorPtr->GetSequenceData(0)[cnt - 1];
+ auto out = this->m_tokenizerStreamPtr->Decode(tok);
+
+ // Record finishing time
+ this->m_totalDecoderTime += Duration(Clock::now() - startTimeStampDecoder).count();
+ this->m_totalDecodedTokens += 1;
+
+ size_t nCurr = this->m_llmGeneratorPtr->GetSequenceCount(0);
+
+ this->m_contextFilled = 100 * nCurr / this->m_nCtx;
+
+ return out;
+ }
+ else {
+ return "<|endoftext|>";
+ }
+}
+
+size_t LLM::LLMImpl::GetChatProgress() const
+{
+ return this->m_contextFilled;
+}
+
+float LLM::LLMImpl::GetEncodeTimings()
+{
+ auto encoderTPS = this->m_totalEncodedTokens / this->m_totalEncoderTime;
+ return encoderTPS;
+}
+
+float LLM::LLMImpl::GetDecodeTimings()
+{
+ auto decoderTPS = this->m_totalDecodedTokens / this->m_totalDecoderTime;
+ return decoderTPS;
+}
+
+void LLM::LLMImpl::ResetTimings()
+{
+ this->m_totalDecoderTime = 0;
+ this->m_totalEncoderTime = 0;
+ this->m_totalDecodedTokens = 0;
+ this->m_totalEncodedTokens = 0;
+ LOG_INF("Reset Timings\n");
+
+}
+
+std::string LLM::LLMImpl::SystemInfo()
+{
+ std::string sysInfo = "\nSystem INFO:\n";
+ std::string deviceType = std::string(this->m_llmModelPtr->GetDeviceType());
+ std::string modelType = std::string(this->m_llmModelPtr->GetType());
+ sysInfo += "Device Type: " + deviceType + "\n";
+ sysInfo += "Model Type: " + modelType + "\n";
+ return sysInfo;
+}
+
+std::string LLM::LLMImpl::BenchModel(int& prompts, int& eval_prompts, int& n_max_sq, int& n_rep)
+{
+ // TODO: Refactor BenchModel() into a framework-agnostic utility:
+ // Abstract the core benchmarking logic into a shared BenchModel(const Config&) function,
+ // migrate each framework submodule to invoke it, and consolidate all parameters into the Config struct.
+ // Benchmarking is not implemented for this framework yet, so return an empty string
+ // (constructing a std::string from a null char pointer is undefined behaviour).
+ return {};
+}
+
+std::string LLM::LLMImpl::GetFrameworkType()
+{
+ return this->m_frameworkType;
+}
diff --git a/src/cpp/interface/Llm.hpp b/src/cpp/interface/Llm.hpp
index 5355551d3f51f3a4056568ee1828a281533f9041..f7a5a22668cf0175303fe496aa27bf0a66762e92 100644
--- a/src/cpp/interface/Llm.hpp
+++ b/src/cpp/interface/Llm.hpp
@@ -91,6 +91,12 @@ public:
*/
std::string BenchModel(int& nPrompts, int& nEvalPrompts, int& nMaxSeq, int& nRep);
+
+ /**
+ * Method to get framework type
+ * @return string framework type
+ */
+ std::string GetFrameworkType();
};
#endif /* ARM_LLM_HPP */
diff --git a/src/java/com/arm/Llm.java b/src/java/com/arm/Llm.java
index 5c87432e04256189685320b97ff1202eae2a2e57..a994487553d9c04b728a8b568d3c69a72185b94c 100644
--- a/src/java/com/arm/Llm.java
+++ b/src/java/com/arm/Llm.java
@@ -31,6 +31,7 @@ public class Llm extends SubmissionPublisher
private long llmPtr = 0;
private String modelTag = "";
private String userTag = "";
+ private String endTag = "";
private List<String> stopWords = null;
private String cachedToken = "";
private String emitToken = "";
@@ -43,14 +44,17 @@ public class Llm extends SubmissionPublisher
/**
* Method to create LlmConfig cpp instance from params.
* @param modelTag name used to refer the model
+ * @param userTag tag used to refer the user
+ * @param endTag tag to specify the end of the query
* @param modelPath path to load model from
* @param llmPrefix Initial-prompt to load into llm before query
* @param numThreads Number of threads for inference
* @param batchSize batch size used to chunk queries
* @return pointer to llm config
*/
- public native long createLlmConfig(String modelTag, String modelPath, String llmPrefix,
- int numThreads, int batchSize);
+ public native long createLlmConfig(String modelTag, String userTag, String endTag,
+ String modelPath, String llmPrefix, int numThreads,
+ int batchSize);
/**
* Method for loading LLM model
* @param LlmConfig load model from LlmConfig
@@ -107,7 +111,6 @@ public class Llm extends SubmissionPublisher
/**
* Method to decode answers one by one, once prefill stage is completed
- *
* @param nPrompts prompt length used for benchmarking
* @param nEvalPrompts number of generated tokens for benchmarking
* @param nMaxSeq sequence number
@@ -121,6 +124,12 @@ public class Llm extends SubmissionPublisher
int nRep
);
+ /**
+ * Method to get framework type
+ * @return string framework type
+ */
+ public native String getFrameworkType();
+
/**
* Method to separate Initialization from constructor
* @param llmConfig type configuration file to load model
@@ -130,10 +139,12 @@ public class Llm extends SubmissionPublisher
this.stopWords = llmConfig.getStopWords();
this.modelTag = llmConfig.getModelTag();
this.userTag = llmConfig.getUserTag();
+ this.endTag = llmConfig.getEndTag();
this.llmPrefix = llmConfig.getLlmPrefix();
this.numThreads = llmConfig.getNumThreads();
- long configPtr = createLlmConfig(this.modelTag,llmConfig.getModelPath(),
- this.llmPrefix,this.numThreads,this.batchSize);
+ long configPtr = createLlmConfig(this.modelTag, this.userTag, this.endTag,
+ llmConfig.getModelPath(), this.llmPrefix,
+ this.numThreads, this.batchSize);
this.llmPtr = loadModel(configPtr);
}
@@ -143,7 +154,6 @@ public class Llm extends SubmissionPublisher
*/
public void setSubscriber(Flow.Subscriber subscriber)
{
- System.out.println("subscribed set from llama");
this.subscribe(subscriber);
}
@@ -156,9 +166,9 @@ public class Llm extends SubmissionPublisher
String query = "";
AtomicBoolean stop = new AtomicBoolean(false);
if (evaluatedOnce.get())
- query = userTag + Query + modelTag;
+ query = userTag + Query + endTag + modelTag;
else
- query = llmPrefix + Query + modelTag;
+ query = llmPrefix + userTag + Query + endTag + modelTag;
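+        // Illustrative only, with assumed tag values (userTag = "User: ", endTag = "<|end|>",
+        // modelTag = "Assistant: "): after the first turn the query is assembled as
+        // "User: <Query><|end|>Assistant: ", with llmPrefix prepended on the very first turn.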
encode(query);
evaluatedOnce.set(true);
while (getChatProgress()<100)
@@ -189,9 +199,9 @@ public class Llm extends SubmissionPublisher
String query = "";
boolean stop = false;
if (evaluatedOnce.get())
- query = userTag + Query + modelTag;
+ query = userTag + Query + endTag + modelTag;
else
- query = llmPrefix + Query + modelTag;
+ query = llmPrefix + userTag + Query + endTag + modelTag;
encode(query);
evaluatedOnce.set(true);
while (getChatProgress()<100)
diff --git a/src/java/com/arm/LlmConfig.java b/src/java/com/arm/LlmConfig.java
index 169b350d847d2245215e66f87763237647f7ccbc..1897bbd49abd5ecff7b52daf96e249eaf9d4881c 100644
--- a/src/java/com/arm/LlmConfig.java
+++ b/src/java/com/arm/LlmConfig.java
@@ -15,13 +15,14 @@ public class LlmConfig
{
private String modelTag;
private String userTag;
+ private String endTag;
private String modelPath;
private String llmPrefix;
private List stopWords;
private int numThreads;
/**
- * Minimal constructor without userTag and numThreads
+ * Minimal constructor without userTag, endTag and numThreads
*
* @param modelTag tag for the model
* @param stopWords stop words to use
@@ -30,7 +31,7 @@ public class LlmConfig
*/
public LlmConfig(String modelTag, List stopWords, String modelPath, String llmPrefix)
{
- this(modelTag, stopWords, modelPath, llmPrefix, "", 4);
+ this(modelTag, stopWords, modelPath, llmPrefix, "", "", 4);
}
/**
@@ -41,15 +42,16 @@ public class LlmConfig
* @param modelPath path to the model
* @param llmPrefix llm prefix to use
* @param userTag user tag to use
+ * @param endTag end tag to use
*/
- public LlmConfig(String modelTag, List stopWords, String modelPath, String llmPrefix, String userTag)
+ public LlmConfig(String modelTag, List stopWords, String modelPath, String llmPrefix, String userTag, String endTag)
{
// Use 4 threads by default
- this(modelTag, stopWords, modelPath, llmPrefix, userTag, 4);
+ this(modelTag, stopWords, modelPath, llmPrefix, userTag, endTag, 4);
}
/**
- * Minimal constructor without userTag
+     * Minimal constructor without userTag and endTag
*
* @param modelTag tag for the model
* @param stopWords stop words to use
@@ -57,9 +59,9 @@ public class LlmConfig
* @param llmPrefix llm prefix to use
* @param numThreads number of threads to use
*/
- public LlmConfig(String modelTag, List stopWords, String modelPath, String llmPrefix,int numThreads)
+ public LlmConfig(String modelTag, List stopWords, String modelPath, String llmPrefix, int numThreads)
{
- this(modelTag, stopWords, modelPath, llmPrefix, "", numThreads);
+ this(modelTag, stopWords, modelPath, llmPrefix, "", "", numThreads);
}
/**
@@ -70,16 +72,18 @@ public class LlmConfig
* @param modelPath path to the model
* @param llmPrefix llm prefix to use
* @param userTag user tag to use
+ * @param endTag end tag to use
* @param numThreads number of threads to use
*/
public LlmConfig(String modelTag, List stopWords, String modelPath,
- String llmPrefix, String userTag, int numThreads)
+ String llmPrefix, String userTag, String endTag, int numThreads)
{
this.modelTag = modelTag;
this.stopWords = stopWords;
this.modelPath = modelPath;
this.llmPrefix = llmPrefix;
this.userTag = userTag;
+ this.endTag = endTag;
this.numThreads = numThreads;
}
@@ -101,6 +105,15 @@ public class LlmConfig
{
return this.userTag;
}
+ /**
+ * Gets the end tag.
+ *
+ * @return The end tag.
+ */
+ public String getEndTag()
+ {
+ return this.endTag;
+ }
/**
* Gets the list of stop words.
@@ -161,6 +174,16 @@ public class LlmConfig
this.userTag = userTag;
}
+ /**
+ * Sets the end tag.
+ *
+ * @param endTag The end tag to set.
+ */
+ public void setEndTag(String endTag)
+ {
+ this.endTag = endTag;
+ }
+
/**
* Sets the list of stop words.
*
@@ -192,10 +215,10 @@ public class LlmConfig
}
/**
- * Sets the number of Threads.
- * @param numThreads count of threads to use for LLM.
- */
-
+ * Sets the number of Threads.
+ *
+ * @param numThreads count of threads to use for LLM.
+ */
public void setNumThreads(int numThreads)
{
this.numThreads = numThreads;
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 54886b2ba318740457be040411f5020776ed55e9..31f2099c34a66db636fc3187439a021443ea58aa 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -25,8 +25,10 @@ add_executable(llm-cpp-tests
${CMAKE_CURRENT_SOURCE_DIR}/cpp/LlmUtils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/LlmTest.cpp)
-if (${LLM_DEP_NAME} STREQUAL "llama.cpp")
+if (${LLM_FRAMEWORK} STREQUAL "llama.cpp")
set(CONFIG_FILE_NAME "llamaConfig.txt" CACHE STRING "Path to the Llama config file")
+elseif (${LLM_FRAMEWORK} STREQUAL "onnxruntime-genai")
+ set(CONFIG_FILE_NAME "onnxrtConfig.txt" CACHE STRING "Path to the ONNX Runtime GenAI config file")
endif ()
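+
+# Illustrative only (the exact invocation depends on your preset/toolchain): configuring with
+#   cmake -DLLM_FRAMEWORK=onnxruntime-genai ..
+# selects onnxrtConfig.txt here, while the default llama.cpp framework keeps llamaConfig.txt.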
diff --git a/test/cpp/LlmTest.cpp b/test/cpp/LlmTest.cpp
index d0f520bc66ac95ddf272ce8417ac71fcd5a083aa..baaa897a1c4bfce23bc67f54e59344ac0f6002f4 100644
--- a/test/cpp/LlmTest.cpp
+++ b/test/cpp/LlmTest.cpp
@@ -49,7 +49,8 @@ TEST_CASE("Test Llm-Wrapper class")
SetupTestConfig(stopWordsStream, &configTest, STOP_WORDS);
std::string response;
- std::string question = "What is the capital of France?" + configTest.GetModelTag();
+    std::string question = configTest.GetUserTag() + "What is the capital of France?" +
+                           configTest.GetEndTag() + configTest.GetModelTag();
std::string prefixedQuestion = configTest.GetLlmPrefix() + question;
LLM llm;
diff --git a/test/cpp/LlmUtils.cpp b/test/cpp/LlmUtils.cpp
index 3d90bae37d704f071464539be6de310a45e8eb0e..96acd00cbf5353e47f52722fe89610317aaa63d1 100644
--- a/test/cpp/LlmUtils.cpp
+++ b/test/cpp/LlmUtils.cpp
@@ -79,6 +79,10 @@ LlmConfig GetConfig(std::unordered_map config,
throw std::runtime_error("Missing required parameter: modelPath");
if (config.find("modelTag") == config.end())
throw std::runtime_error("Missing required parameter: modelTag");
+ if (config.find("userTag") == config.end())
+ throw std::runtime_error("Missing required parameter: userTag");
+ if (config.find("endTag") == config.end())
+ throw std::runtime_error("Missing required parameter: endTag");
if (config.find("llmPrefix") == config.end())
throw std::runtime_error("Missing required parameter: llmPrefix");
@@ -90,6 +94,8 @@ LlmConfig GetConfig(std::unordered_map config,
throw std::runtime_error("Missing required parameter: stopWords");
return LlmConfig(config.at("modelTag"),
+ config.at("userTag"),
+ config.at("endTag"),
config.at("modelPath"),
config.at("llmPrefix"),
userConfig.at("numThreads"),
diff --git a/test/java/com/arm/LlmTestJNI.java b/test/java/com/arm/LlmTestJNI.java
index f1e5f4a0d8602efa871b1fdd0f20c9bb57fcd48c..4886579d8347b78e7d2903a4a064b31318a45ba1 100644
--- a/test/java/com/arm/LlmTestJNI.java
+++ b/test/java/com/arm/LlmTestJNI.java
@@ -27,6 +27,7 @@ public class LlmTestJNI {
private static int numThreads = 4;
private static String modelTag = "";
private static String userTag = "";
+ private static String endTag = "";
private static String modelPath = "";
private static String llmPrefix = "";
private static List stopWords = new ArrayList();
@@ -80,6 +81,7 @@ public class LlmTestJNI {
loadVariables(configFilePath);
modelTag = variables.get("modelTag");
userTag = variables.getOrDefault("userTag","");
+ endTag = variables.getOrDefault("endTag", "");
llmPrefix = variables.get("llmPrefix");
modelPath = modelDir + "/" + variables.get("llmModelName");
loadVariables(userConfigFilePath);
@@ -95,7 +97,7 @@ public class LlmTestJNI {
@Test
public void testConfigLoading() {
- LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag,numThreads);
+ LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix, userTag, endTag, numThreads);
assertTrue("Model tag is not empty", !llmConfig.getModelTag().isEmpty());
assertTrue("LLM prefix is not empty", !llmConfig.getLlmPrefix().isEmpty());
assertTrue("Stop words list is not empty", !llmConfig.getStopWords().isEmpty());
@@ -103,11 +105,11 @@ public class LlmTestJNI {
@Test
public void testLlmPrefixSetting() {
- LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag);
+ LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix, userTag, endTag, numThreads);
Llm llm = new Llm();
llm.llmInit(llmConfig);
- String newModelTag = ("Ferdia");
+ String newModelTag = ("Ferdia:");
String newPrefix = "Transcript of a dialog, where the User interacts with an AI Assistant named " + newModelTag +
". " + newModelTag +
" is helpful, polite, honest, good at writing and answers honestly with a maximum of two sentences. User:";
@@ -123,27 +125,26 @@ public class LlmTestJNI {
@Test
public void testInferenceWithContextReset() {
- LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag,numThreads);
- Llm llm = new Llm();
- llm.llmInit(llmConfig);
-
- String question1 = "What is the capital of the country, Morocco?";
- String response1 = llm.send(question1);
- checkLlmMatch(response1, "Rabat", true);
+ LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix, userTag, endTag, numThreads);
+ Llm llm = new Llm();
+ llm.llmInit(llmConfig);
- // Resetting context should cause model to forget what country is being referred to
- llm.resetContext();
+ String question1 = "What is the capital of the country, Morocco?";
+ String response1 = llm.send(question1);
+ checkLlmMatch(response1, "Rabat", true);
- String question2 = "What languages do they speak there?";
- String response2 = llm.send(question2);
- checkLlmMatch(response2, "Arabic", false);
+ // Resetting context should cause model to forget what country is being referred to
+ llm.resetContext();
- llm.freeModel();
+ String question2 = "What languages do they speak there?";
+ String response2 = llm.send(question2);
+ checkLlmMatch(response2, "Arabic", false);
+ llm.freeModel();
}
@Test
public void testInferenceWithoutContextReset() {
- LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag,numThreads);
+ LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix, userTag, endTag, numThreads);
Llm llm = new Llm();
llm.llmInit(llmConfig);
@@ -154,13 +155,12 @@ public class LlmTestJNI {
String question2 = "What languages do they speak there?";
String response2 = llm.send(question2);
checkLlmMatch(response2, "Arabic", true);
-
llm.freeModel();
}
@Test
public void testInferenceHandlesEmptyQuestion() {
- LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag,numThreads);
+ LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix, userTag, endTag, numThreads);
Llm llm = new Llm();
llm.llmInit(llmConfig);
@@ -177,14 +177,13 @@ public class LlmTestJNI {
String question2 = "What languages do they speak there?";
String response2 = llm.send(question2);
checkLlmMatch(response2, "Arabic", true);
-
llm.freeModel();
}
@Test
public void testMangoSubtractionLongConversation() {
- LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag,numThreads);
+ LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix, userTag, endTag, numThreads);
Llm llm = new Llm();
llm.llmInit(llmConfig);
@@ -195,8 +194,8 @@ public class LlmTestJNI {
// Set the initial ground truth in the conversation.
String initialContext = "There are " + originalMangoes + " mangoes in a basket.";
String initResponse = llm.send(initialContext);
- String originalQuery = "How many mangoes did we start with? ";
- String subtractQuery = "Subtract 1 mango. How many mangoes are left now? ";
+ String originalQuery = "How many mangoes did we start with?";
+ String subtractQuery = "Remove 1 mango from the basket. How many mangoes left in the basket now?";
// **Assert that the model acknowledges the initial count of mangoes.**
checkLlmMatch(initResponse, String.valueOf(originalMangoes), true);
@@ -204,6 +203,11 @@ public class LlmTestJNI {
// Loop to subtract 1 mango each iteration until reaching 0.
for (int i = 1; i < originalMangoes; i++) {
+ // Modify the query during the conversation
+ if (i == 2) {
+ subtractQuery = "Good, remove 1 mango again from the basket. How many mangoes left in the basket now?";
+ }
+
// Query to subtract one mango
String subtractionResponse = llm.send(subtractQuery);
mangoes -= 1; // Update our expected count
@@ -219,6 +223,7 @@ public class LlmTestJNI {
}
String postResetResponse = llm.send(originalQuery);
+
checkLlmMatch(postResetResponse, String.valueOf(originalMangoes), false);
llm.freeModel();
}
@@ -232,7 +237,7 @@ public class LlmTestJNI {
throw new RuntimeException("System properties for model_dir or config_file are not set!");
}
- LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag,numThreads);
+ LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix, userTag, endTag, numThreads);
Llm llm = new Llm();
llm.llmInit(llmConfig);
@@ -257,8 +262,6 @@ public class LlmTestJNI {
checkLlmMatch(response4, "Arabic", true);
checkLlmMatch(response4, "French", true);
-
- // Free model after use
llm.freeModel();
}
}