diff --git a/CMakeLists.txt b/CMakeLists.txt index 2e3afcb6641406dffa0139166c67776827d40428..000e87c23be336675e01688268d9eaa555c018b9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,11 +6,17 @@ cmake_minimum_required(VERSION 3.27) +set(CMAKE_CXX_STANDARD 17) + # If CMake toolchain is not defined, set it here. if (NOT CMAKE_TOOLCHAIN_FILE) set(CMAKE_TOOLCHAIN_FILE "${CMAKE_CURRENT_SOURCE_DIR}/scripts/cmake/toolchains/native.cmake") endif() +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) + message(STATUS "Using CMAKE_TOOLCHAIN_FILE: ${CMAKE_TOOLCHAIN_FILE}") # Declare project diff --git a/model_configuration_files/LLMUserConfig.txt b/model_configuration_files/LLMUserConfig.txt new file mode 100644 index 0000000000000000000000000000000000000000..d8db5ce1246b65c058e8ce2474f051f41ff04a6e --- /dev/null +++ b/model_configuration_files/LLMUserConfig.txt @@ -0,0 +1,2 @@ +numThreads=4 +batchSize=256 \ No newline at end of file diff --git a/model_configuration_files/llamaConfig.txt b/model_configuration_files/llamaConfig.txt index c71926162471520211c5fa6deece8219fc26ac47..457f2b60170b5fda6a6235077a7d99946dc0ec9d 100755 --- a/model_configuration_files/llamaConfig.txt +++ b/model_configuration_files/llamaConfig.txt @@ -1,3 +1,4 @@ -modelTagDefault=Orbita: -llmPrefixDefault=Transcript of a dialog, where the User interacts with an AI Assistant named Orbita. Orbita is helpful, polite, honest, good at writing and answers honestly with a maximum of two sentences. User: -stopWordsDefault=Orbita:,User:,AI:,<|user|>,Assistant:,user:,[end of text],<|endoftext|>,model:,Question:,"\n\n",Consider the following scenario:\n +modelTag=Orbita: +llmPrefix=Transcript of a dialog, where the User interacts with an AI Assistant named Orbita. Orbita is helpful, polite, honest, good at writing and answers honestly with a maximum of two sentences. 
User: +stopWords=Orbita:,User:,AI:,<|user|>,Assistant:,user:,[end of text],<|endoftext|>,model:,Question:,"\n\n",Consider the following scenario:\n +llmModelName=model.gguf \ No newline at end of file diff --git a/scripts/cmake/check-flag.cmake b/scripts/cmake/check-flag.cmake index c169fcd879ad28f6ca4925d15d0e73b975fd4657..f714a2e4d8dca95e0f332ac6ffc2602e17489662 100644 --- a/scripts/cmake/check-flag.cmake +++ b/scripts/cmake/check-flag.cmake @@ -1,5 +1,6 @@ # # SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +# # SPDX-License-Identifier: Apache-2.0 # include_guard(DIRECTORY) diff --git a/scripts/py/download_resources.py b/scripts/py/download_resources.py index af9b8b4609bfe4b7fc451f8565aa3319eaf97b12..22ccbc32ccb200257df7cda326de7922f4db6f6f 100644 --- a/scripts/py/download_resources.py +++ b/scripts/py/download_resources.py @@ -1,9 +1,8 @@ # -# SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates # # SPDX-License-Identifier: Apache-2.0 # - import json import hashlib from pathlib import Path @@ -12,7 +11,6 @@ import urllib.request import logging import sys from argparse import ArgumentParser - def download_file(url: str, dest: Path) -> None: """ Download a file @@ -69,7 +67,7 @@ def download_resources(resources_file: Path, download_dir: Path) -> None: logging.info(f'SHA256: {resource_data["sha256sum"]}') url = resource_data['url'] - dest = resource_dir / resource_data['destination'] + dest = resource_dir / resource_data['destination'] if dest.exists(): logging.info(f'{dest} exists; skipping download') @@ -103,9 +101,9 @@ if __name__ == "__main__": args = parser.parse_args() req_file = Path(args.requirements_file) - download_directory = Path(args.download_dir) + download_dir = Path(args.download_dir) if not req_file.exists(): raise FileNotFoundError(f'{req_file} does not exist') - download_resources(req_file, download_directory) + download_resources(req_file, download_dir) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7b6d61290920168ce4fe35f057b20d04df60dfb1..e4d5fc4a8e3326563fe0f1880c36f0404353482f 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -3,7 +3,6 @@ # # SPDX-License-Identifier: Apache-2.0 # - add_subdirectory(cpp) if(BUILD_JNI_LIB) diff --git a/src/cpp/CMakeLists.txt b/src/cpp/CMakeLists.txt index caf71b61e3514ba3ea3003cc5f1f02c295b2690b..cdc34f6df1f82eed054913870b44e5b568bdf071 100644 --- a/src/cpp/CMakeLists.txt +++ b/src/cpp/CMakeLists.txt @@ -6,24 +6,40 @@ # Declare project project(arm-llm-cpp-prj - VERSION 0.0.1 - DESCRIPTION "An LLM CPP interface" + DESCRIPTION "An LLM CPP library" LANGUAGES C CXX ASM) -# Add the LLM API interface library -# NOTE: This is an interface library right now because our wrapper is -# only a header file. If we need to add source files later, this -# would change to a SHARED or STATIC library. 
-add_library(arm-llm-cpp INTERFACE) -target_include_directories(arm-llm-cpp INTERFACE - ${CMAKE_CURRENT_SOURCE_DIR}/include) - -if (${LLM_DEP_NAME} STREQUAL "llama.cpp") - add_subdirectory(llama_cpp) -#elseif(${LLM_DEP_NAME} STREQUAL "executorch") -# # TODO -#elseif(${LLM_DEP_NAME} STREQUAL "mediapipe") -# # TODO -else() - message(FATAL_ERROR "${LLM_DEP_NAME} is currently not supported :(") +# Add interface: +add_subdirectory(interface) + +# Add configuration lib: +add_subdirectory(config) + +# Add frameworks lib: +add_subdirectory(frameworks) + +# Actual API library that depends on framework +add_library(arm-llm-cpp STATIC + ${CMAKE_CURRENT_SOURCE_DIR}/Llm.cpp) + +# Libraries needed +target_link_libraries(arm-llm-cpp PUBLIC + arm-llm-interface + arm-llm-config + arm-llm-framework) + +# If building with JNI support, build a separate +# target to provide these bindings. +if (BUILD_JNI_LIB) + # Make sure JNI include directories have been set. + include(find-jni) + + add_library(arm-llm-jni SHARED + ${CMAKE_CURRENT_SOURCE_DIR}/LlmJni.cpp) + + target_link_libraries(arm-llm-jni PUBLIC arm-llm-cpp) + + target_include_directories(arm-llm-jni PUBLIC + ${JNI_INCLUDE_DIRS} # Populated by FindJNI CMake module + ) endif() diff --git a/src/cpp/Llm.cpp b/src/cpp/Llm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5e337bb2541bd040c2514c540bd45e01e5028667 --- /dev/null +++ b/src/cpp/Llm.cpp @@ -0,0 +1,71 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "LlmImpl.hpp" + +LLM::LLM() +{ + this->m_impl = std::make_unique(); +} + +LLM::~LLM() +{ + this->FreeLlm(); +} + +void LLM::LlmInit(const LlmConfig& llmConfig) +{ + this->m_impl->LlmInit(llmConfig); +} + +void LLM::FreeLlm() +{ + this->m_impl->FreeLlm(); +} + +float LLM::GetEncodeTimings() +{ + return this->m_impl->GetEncodeTimings(); +} + +float LLM::GetDecodeTimings() +{ + return this->m_impl->GetDecodeTimings(); +} + +void LLM::ResetTimings() +{ + this->m_impl->ResetTimings(); +} + +std::string LLM::SystemInfo() +{ + return this->m_impl->SystemInfo(); +} + +void LLM::ResetContext() +{ + this->m_impl->ResetContext(); +} + +void LLM::Encode(std::string text) +{ + this->m_impl->Encode(text); +} + +std::string LLM::NextToken() +{ + return this->m_impl->NextToken(); +} + +size_t LLM::GetChatProgress() +{ + return this->m_impl->GetChatProgress(); +} + +std::string LLM::BenchModel(int& nPrompts, int& nEvalPrompts, int& nMaxSeq, int& nRep) +{ + return this->m_impl->BenchModel(nPrompts, nEvalPrompts, nMaxSeq, nRep); +} diff --git a/src/cpp/LlmJni.cpp b/src/cpp/LlmJni.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4f595dd1533cf089a4b74f16c9ed05d38f9a59cd --- /dev/null +++ b/src/cpp/LlmJni.cpp @@ -0,0 +1,103 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "LlmConfig.hpp" +#include "LlmImpl.hpp" + +#include + +static std::unique_ptr llm{nullptr}; + +#ifdef __cplusplus +extern "C" { +#endif + +JNIEXPORT jlong JNICALL Java_com_arm_Llm_createLlmConfig(JNIEnv* env, + jobject /* this */, + jstring jModelTag, + jstring jModelPath, + jstring jLlmPrefix, + jint jNumThreads, + jint jBatchSize) +{ + const char* modelTag = env->GetStringUTFChars(jModelTag, nullptr); + const char* modelPath = env->GetStringUTFChars(jModelPath, nullptr); + const char* llmPrefix = env->GetStringUTFChars(jLlmPrefix, nullptr); + + auto* config = new 
LlmConfig(std::string(modelTag), + std::string(modelPath), + std::string(llmPrefix), + static_cast(jNumThreads), + static_cast(jBatchSize)); + + // Clean up + env->ReleaseStringUTFChars(jModelTag, modelTag); + env->ReleaseStringUTFChars(jModelPath, modelPath); + env->ReleaseStringUTFChars(jLlmPrefix, llmPrefix); + + return reinterpret_cast(config); // Return pointer as long +} + +JNIEXPORT jlong JNICALL Java_com_arm_Llm_loadModel(JNIEnv* env, jobject, jlong pconfig) +{ + auto config = reinterpret_cast(pconfig); + llm = std::make_unique(); + llm->LlmInit(*config); + return 0; +} + +JNIEXPORT void JNICALL Java_com_arm_Llm_freeLlm(JNIEnv*, jobject) +{ + llm->FreeLlm(); +} + +JNIEXPORT void JNICALL Java_com_arm_Llm_encode(JNIEnv* env, jobject, jstring jtext) +{ + const auto text = env->GetStringUTFChars(jtext, 0); + llm->Encode(text); + env->ReleaseStringUTFChars(jtext, text); +} + +JNIEXPORT jstring JNICALL Java_com_arm_Llm_getNextToken(JNIEnv* env, jobject) +{ + std::string result = llm->NextToken(); + return env->NewStringUTF(result.c_str()); +} + +JNIEXPORT jfloat JNICALL Java_com_arm_Llm_getEncodeRate(JNIEnv* env, jobject) +{ + float result = llm->GetEncodeTimings(); + return result; +} + +JNIEXPORT jfloat JNICALL Java_com_arm_Llm_getDecodeRate(JNIEnv* env, jobject) +{ + + float result = llm->GetDecodeTimings(); + return result; +} + +JNIEXPORT void JNICALL Java_com_arm_Llm_resetTimings(JNIEnv* env, jobject) +{ + llm->ResetTimings(); +} +JNIEXPORT jsize JNICALL Java_com_arm_Llm_getChatProgress(JNIEnv* env, jobject) +{ + return llm->GetChatProgress(); +} +JNIEXPORT void JNICALL Java_com_arm_Llm_resetContext(JNIEnv* env, jobject) +{ + llm->ResetContext(); +} +JNIEXPORT jstring JNICALL Java_com_arm_Llm_benchModel( + JNIEnv* env, jobject, jint nPrompts, jint nEvalPrompts, jint nMaxSeq, jint nRep) +{ + std::string result = llm->BenchModel(nPrompts, nEvalPrompts, nMaxSeq, nRep); + return env->NewStringUTF(result.c_str()); +} + +#ifdef __cplusplus +} +#endif diff --git a/src/cpp/config/CMakeLists.txt b/src/cpp/config/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..95bb4d03b9b2adec8066bc05881486dfc4f7c274 --- /dev/null +++ b/src/cpp/config/CMakeLists.txt @@ -0,0 +1,16 @@ +# +# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# +project(arm-llm-config + VERSION 0.0.1 + DESCRIPTION "LLM configuration" + LANGUAGES CXX) + +# Add the LLM config library +add_library(arm-llm-config STATIC + ${CMAKE_CURRENT_SOURCE_DIR}/LlmConfig.cpp) + +target_include_directories(arm-llm-config PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/src/cpp/config/LlmConfig.cpp b/src/cpp/config/LlmConfig.cpp new file mode 100644 index 0000000000000000000000000000000000000000..db5f95fd8bf65bf9039ccc81258fa4a4c8ad9902 --- /dev/null +++ b/src/cpp/config/LlmConfig.cpp @@ -0,0 +1,74 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "LlmConfig.hpp" +#include + +LlmConfig::LlmConfig(const std::string& modelTag, + const std::string& modelPath, + const std::string& llmPrefix, + int numThreads, + int batchSize) : + m_modelTag(modelTag), m_modelPath(modelPath), m_llmPrefix(llmPrefix) +{ + SetNumThreads(numThreads); + SetBatchSize(batchSize); +} + +std::string LlmConfig::GetModelTag() const +{ + return this->m_modelTag; +} + +std::string LlmConfig::GetModelPath() const +{ + return this->m_modelPath; +} + +std::string 
LlmConfig::GetLlmPrefix() const +{ + return this->m_llmPrefix; +} + +int LlmConfig::GetNumThreads() const +{ + return this->m_numThreads; +} + +int LlmConfig::GetBatchSize() const +{ + return this->m_batchSize; +} + +void LlmConfig::SetModelTag(const std::string& modelIdentifier) +{ + this->m_modelTag = modelIdentifier; +} + +void LlmConfig::SetModelPath(const std::string& basePath) +{ + this->m_modelPath = basePath; +} + +void LlmConfig::SetLlmPrefix(const std::string& llmInitialPrompt) +{ + this->m_llmPrefix = llmInitialPrompt; +} + +void LlmConfig::SetNumThreads(int threads) +{ + if (threads <= 0) { + throw std::invalid_argument("number of threads must be a positive integer."); + } + this->m_numThreads = threads; +} + +void LlmConfig::SetBatchSize(int batchSz) +{ + if (batchSz <= 0) { + throw std::invalid_argument("batch-size must be a positive integer."); + } + this->m_batchSize = batchSz; +} diff --git a/src/cpp/config/LlmConfig.hpp b/src/cpp/config/LlmConfig.hpp new file mode 100644 index 0000000000000000000000000000000000000000..477d3faae774ff2b934c7d26b758da0029ea0fcb --- /dev/null +++ b/src/cpp/config/LlmConfig.hpp @@ -0,0 +1,91 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef LLM_CONFIG_HPP +#define LLM_CONFIG_HPP + +#include + +class LlmConfig { +private: + std::string m_modelTag{}; + std::string m_modelPath{}; + std::string m_llmPrefix{}; + int m_numThreads{}; + int m_batchSize{}; + +public: + LlmConfig(const std::string& modelTag, + const std::string& modelPath, + const std::string& llmPrefix, + int numThreads, + int batchSize); + + LlmConfig() = default; + + /** + * Returns the model tag string (The name to appear in conversation with the LLM). + * @return modelTag + */ + std::string GetModelTag() const; + + /** + * Returns the path to the model file. + * @return modelPath + */ + std::string GetModelPath() const; + + /** + * Returns the LLM prompt prefix string. + * @return llmPrefix + */ + std::string GetLlmPrefix() const; + + /** + * Returns the number of threads configured for inference. + * @return number of Threads + */ + int GetNumThreads() const; + + /** + * Returns the batch size used for querying. + * @return batch size + */ + int GetBatchSize() const; + + /** + * Sets the model tag (The name to appear in conversation with the LLM).. + * @param modelIdentifier is the tag name added at the end of each user question to make model + * respond appropriately + */ + void SetModelTag(const std::string& modelIdentifier); + + /** + * Sets the file path to the model. + * @param basePath absolute path to load llm model + */ + void SetModelPath(const std::string& basePath); + + /** + * Method sets the prompt prefix used for LLM inputs. + * @param llmInitialPrompt LLM's need to prompt engineered to respond intelligently. + * Provide an engineered initial-prompt here. + */ + void SetLlmPrefix(const std::string& llmInitialPrompt); + + /** + Sets the number of threads to use for LLM model inference + @param threads number of threads used inference of model + */ + void SetNumThreads(int threads); + + /** + Sets the batch size for inference. Throws std::invalid_argument if the value is not positive. 
+ @param batchSz chunk-size of each batch used to split query-encoding + */ + void SetBatchSize(int batchSz); +}; + +#endif /* LLM_CONFIG_HPP */ diff --git a/src/cpp/config/README.md b/src/cpp/config/README.md new file mode 100644 index 0000000000000000000000000000000000000000..954b960aa2499220d4a206f06c87a4b7e483e9a0 --- /dev/null +++ b/src/cpp/config/README.md @@ -0,0 +1,8 @@ + + +This space is for all configuration related sources. If there are dependencies +added later, it should only show up in the root CMake file here. diff --git a/src/cpp/frameworks/CMakeLists.txt b/src/cpp/frameworks/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..84166829cfa06fc156fc22ecff2e4e179d5464de --- /dev/null +++ b/src/cpp/frameworks/CMakeLists.txt @@ -0,0 +1,12 @@ +# +# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +# Pull in LLM framework library +if (${LLM_DEP_NAME} STREQUAL "llama.cpp") + add_subdirectory(llama_cpp) +else() + message(FATAL_ERROR "${LLM_DEP_NAME} is currently not supported :(") +endif() diff --git a/src/cpp/frameworks/README.md b/src/cpp/frameworks/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d676725a2d07976d046eca9956d2c47f9a0b0c0e --- /dev/null +++ b/src/cpp/frameworks/README.md @@ -0,0 +1,7 @@ + +Frameworks directory contains different backends we can choose from +to provide implementation logic for our interface. diff --git a/src/cpp/frameworks/llama_cpp/CMakeLists.txt b/src/cpp/frameworks/llama_cpp/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..08eb37d9eaafc3e2b76c45abea6d213ff96bf389 --- /dev/null +++ b/src/cpp/frameworks/llama_cpp/CMakeLists.txt @@ -0,0 +1,79 @@ +# +# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +# Declare project +project(llama-cpp-wrapper + DESCRIPTION "llama.cpp wrapper interface implementation" + LANGUAGES C CXX ASM) + +include(FetchContent) + +# Where should llama.cpp sources be cloned into? +# It might make sense to download sources into resources as well and not +# every time into the CMake binary directory. However, because we currently +# need to patch it, and possibly do it conditionally based on target type +# we have this arrangement for the time being. +set(LLAMA_SRC_DIR "${CMAKE_BINARY_DIR}/llama.cpp" + CACHE PATH + "Path where llama.cpp repo should be cloned into") + +set(LLAMA_GIT_URL "https://github.com/ggerganov/llama.cpp.git" + CACHE STRING + "Git URL for llama.cpp repo") + +set(LLAMA_GIT_SHA "a4090d1" + CACHE STRING + "Git commit SHA for llama.cpp repo") + +set(LLAMA_BUILD_EXAMPLES ${BUILD_EXECUTABLE} CACHE BOOL "Build llama.cpp examples") + +# If the user has NOT explicitly set GGML_CPU_KLEIDIAI +if (NOT DEFINED GGML_CPU_KLEIDIAI) + # if we are on arm64/aarch64, then default KleidiAI to ON. + if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64|ARM64)$") + set(GGML_CPU_KLEIDIAI ON CACHE BOOL + "Enable KleidiAI by default on ${CMAKE_SYSTEM_PROCESSOR}") + message(STATUS "KleidiAI enabled by default") + # if we are NOT on arm64/aarch64, then default KleidiAI to OFF. 
+ else() + set(GGML_CPU_KLEIDIAI OFF CACHE BOOL + "Disable KleidiAI by default on ${CMAKE_SYSTEM_PROCESSOR}") + message(STATUS "KleidiAI disabled by default") + endif() +else () + message(STATUS "KleidiAI: ${GGML_CPU_KLEIDIAI}") +endif() + +set(GGML_BUILD_NUMBER 1) # We do a shallow clone `--depth=1` + +# Fetch the dependency Git repo here +FetchContent_Declare(llama-cpp + GIT_REPOSITORY ${LLAMA_GIT_URL} + GIT_TAG ${LLAMA_GIT_SHA} + GIT_SHALLOW ${GGML_BUILD_NUMBER} + SOURCE_DIR ${LLAMA_SRC_DIR}) + +FetchContent_MakeAvailable(llama-cpp) + +# Check internally defined dependency targets are visible +if (NOT TARGET arm-llm-config) + message(FATAL_ERROR "arm-llm-config target not defined") +elseif (NOT TARGET arm-llm-interface) + message(FATAL_ERROR "arm-llm-interface target not defined") +endif() + +add_library(arm-llm-framework STATIC + ${CMAKE_CURRENT_SOURCE_DIR}/LlamaImpl.cpp) + +target_include_directories(arm-llm-framework PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}) + +# List all libraries that we need to depend on here: +target_link_libraries(arm-llm-framework PUBLIC + llama + common + arm-llm-config + arm-llm-interface) diff --git a/src/cpp/frameworks/llama_cpp/LlamaImpl.cpp b/src/cpp/frameworks/llama_cpp/LlamaImpl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e5d9e21e8b4e4053b7532f4d87c1df4a9c071038 --- /dev/null +++ b/src/cpp/frameworks/llama_cpp/LlamaImpl.cpp @@ -0,0 +1,437 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "LlmImpl.hpp" + +#define LOG_INF(...) \ + do { \ + fprintf(stdout, __VA_ARGS__); \ + } while (0) + +static bool is_valid_utf8(const char* string); + +/** + * @brief LLama Implementation of our LLM API + * + */ +LLM::LLMImpl::LLMImpl() = default; + +LLM::LLMImpl::~LLMImpl() +{ + this->FreeLlm(); +} + +void LLM::LLMImpl::LoadModel(const char* pathToModel) +{ + const llama_model_params model_params = llama_model_default_params(); + this->m_llmModel = llama_model_load_from_file(pathToModel, model_params); + if (this->m_llmModel == nullptr) { + throw std::runtime_error("error: unable to load model from " + std::string(pathToModel)); + } +} + +void LLM::LLMImpl::FreeModel() +{ + if (this->m_llmModel) { + llama_model_free(this->m_llmModel); + this->m_llmModel = nullptr; + } +} + +void LLM::LLMImpl::NewContext(int numThreads) +{ + llama_context_params ctx_params = llama_context_default_params(); + ctx_params.n_ctx = this->m_nCtx; + ctx_params.n_threads = numThreads; + ctx_params.n_threads_batch = numThreads; + ctx_params.no_perf = false; + this->m_llmContext = llama_init_from_model(this->m_llmModel, ctx_params); + if (this->m_llmContext == nullptr) { + throw std::runtime_error("NewContext failed: Unable to create llama context"); + } +} + +void LLM::LLMImpl::FreeContext() +{ + if (this->m_llmContext) { + llama_free(this->m_llmContext); + this->m_llmContext = nullptr; + } +} + +void LLM::LLMImpl::LlmInit(const LlmConfig& config) +{ + try { + this->m_batchSz = config.GetBatchSize(); + + LoadModel(config.GetModelPath().c_str()); + BackendInit(); + + this->m_llmPrefix = config.GetLlmPrefix(); + + if (this->m_llmModel != nullptr) { + NewContext(config.GetNumThreads()); + } + NewSampler(); + this->m_llmInitialized = true; + } catch (const std::exception& e) { + throw std::runtime_error("Llama initialization failed: " + std::string(e.what()) + "/n"); + } +} + +void LLM::LLMImpl::FreeLlm() +{ + if (this->m_llmInitialized) { + FreeContext(); + FreeModel(); + 
BackendFree(); + this->m_nCur = 0; + FreeSampler(); + this->m_llmInitialized = false; + } +} + +void LLM::LLMImpl::BackendInit() +{ + llama_backend_init(); +} + +void LLM::LLMImpl::BackendFree() +{ + llama_backend_free(); +} + +void LLM::LLMImpl::FreeBatch() +{ + llama_batch_free(this->m_llmBatch); +} + +void LLM::LLMImpl::FreeSampler() +{ + llama_sampler_free(this->m_pLlmSampler); +} + +float LLM::LLMImpl::GetEncodeTimings() +{ + const auto resultsTiming = llama_perf_context(this->m_llmContext); + return static_cast(1e3 / resultsTiming.t_p_eval_ms * resultsTiming.n_p_eval); +} + +float LLM::LLMImpl::GetDecodeTimings() +{ + const auto resultsTiming = llama_perf_context(this->m_llmContext); + return static_cast(1e3 / resultsTiming.t_eval_ms * resultsTiming.n_eval); +} + +void LLM::LLMImpl::ResetTimings() +{ + llama_perf_context_reset(this->m_llmContext); +} + +std::string LLM::LLMImpl::SystemInfo() +{ + return std::string(llama_print_system_info()); +} + +void LLM::LLMImpl::KVCacheClear() +{ + llama_kv_self_clear(this->m_llmContext); +} + +void LLM::LLMImpl::KVCacheSeqRm(int32_t p0, int p1) +{ + // setting sequence ID to negative to match any sequence + int seqId = -1; + llama_kv_self_seq_rm(this->m_llmContext, seqId, p0, p1); +} + +int32_t LLM::LLMImpl::GetInitialPromptLength(const char* text, + int32_t textLength, + bool addSpecial, + bool parseSpecial) +{ + const llama_vocab* vocab = llama_model_get_vocab(this->m_llmModel); + const auto tokens = static_cast(malloc(sizeof(llama_token) * this->m_nCtx)); + return llama_tokenize(vocab, text, textLength, tokens, this->m_nCtx, addSpecial, parseSpecial); +} + +void LLM::LLMImpl::ResetContext() +{ + if (!this->m_llmPrefix.empty()) { + auto n_prefix = GetInitialPromptLength( + this->m_llmPrefix.c_str(), this->m_llmPrefix.length(), true, false); + + KVCacheSeqRm(n_prefix, -1); + this->m_nCur = n_prefix; + } else { + KVCacheClear(); + this->m_nCur = 0; + } +} + +llama_batch LLM::LLMImpl::NewBatch(int numTokens, int embeddings, int numSequenceMax) +{ + return llama_batch_init(numTokens, embeddings, numSequenceMax); +} + +void LLM::LLMImpl::NewSampler() +{ + auto sparams = llama_sampler_chain_default_params(); + sparams.no_perf = false; + this->m_pLlmSampler = llama_sampler_chain_init(sparams); + llama_sampler_chain_add(this->m_pLlmSampler, llama_sampler_init_greedy()); +} + +void LLM::LLMImpl::Encode(std::string& prompt) +{ + const auto prompt_tokens = common_tokenize(this->m_llmContext, prompt, 1); + + size_t promptLength = prompt_tokens.size(); + + // check prompt size + if (promptLength > this->m_nCtx - 4) { + fprintf(stderr, "%s: error: unable to Encode large prompt \n", __func__); + } else if (promptLength + this->m_nCur > this->m_nCtx - 4) { + fprintf(stdout, "%s: warning: unable to Encode prompt context full \n", __func__); + } else if (promptLength <= 1) { + fprintf(stderr, "%s: error: unable to Encode empty prompt \n", __func__); + } else { + for (size_t idx = 0; idx < promptLength; idx += this->m_batchSz) { + const size_t end_idx = std::min(idx + this->m_batchSz, promptLength - 1); + const bool lastBatch = (end_idx == (promptLength - 1)); + auto sub_prompt = std::vector(prompt_tokens.begin() + idx, + prompt_tokens.begin() + end_idx + 1); + if (!sub_prompt.empty()) { + CompletionInit(sub_prompt, lastBatch); + } + } + } +} + +void LLM::LLMImpl::CompletionInit(llama_tokens sub_tokens_list, bool lastBatch) +{ + // Synchronize llama to remove idle time between function calls + llama_synchronize(this->m_llmContext); + llama_batch batch = 
NewBatch(this->m_batchSz, 0, 1); + common_batch_clear(batch); + // evaluate the initial prompt + for (auto i = this->m_nCur; i < sub_tokens_list.size() + this->m_nCur; i++) { + common_batch_add(batch, sub_tokens_list[i - this->m_nCur], i, {0}, false); + } + + // llama_decode will output logits only for the last token of the prompt + if (lastBatch) { + batch.logits[batch.n_tokens - 1] = true; + } + + if (llama_decode(this->m_llmContext, batch) != 0) { + LOG_INF("llama_decode() failed"); + return; + } + + llama_synchronize(this->m_llmContext); + this->m_nCur += batch.n_tokens; +} + +std::string LLM::LLMImpl::CompletionLoop() +{ + const auto model = + llama_get_model(this->m_llmContext); // CHANGE FROM JOBJECT TO PASSING ACTUAL CONTEXT + + const llama_vocab* vocab = llama_model_get_vocab(model); + + const auto new_token_id = llama_sampler_sample(this->m_pLlmSampler, this->m_llmContext, -1); + + if ((llama_vocab_eos(vocab) == new_token_id) || (this->m_nCur == this->m_nCtx)) { + return "<|endoftext|>"; + } + + auto new_token_chars = common_token_to_piece(this->m_llmContext, new_token_id); + this->m_cachedTokenChars += new_token_chars; + std::string new_token = ""; + if (is_valid_utf8(this->m_cachedTokenChars.c_str())) { + new_token = this->m_cachedTokenChars.c_str(); + this->m_cachedTokenChars.clear(); + } else { + new_token = ""; + } + llama_batch batch = NewBatch(this->m_batchSz, 0, 1); + common_batch_clear(batch); + common_batch_add(batch, new_token_id, this->m_nCur, {0}, true); + + if (llama_decode(this->m_llmContext, batch) != 0) { + LOG_INF("llama_decode() failed"); + } + + // Synchronize llama to remove idle time between function calls + llama_synchronize(this->m_llmContext); + ++this->m_nCur; + return new_token; +} + +std::string LLM::LLMImpl::NextToken() +{ + std::string result = CompletionLoop(); + if ((result == "<|endoftext|>") && (this->m_nCur >= this->m_nCtx)) { + this->m_contextFilled = 100; + return "ctx_full"; + } else { + this->m_contextFilled = 100 * this->m_nCur / this->m_nCtx; + } + return result; +} + +size_t LLM::LLMImpl::GetChatProgress() const +{ + return this->m_contextFilled; +} + +std::string LLM::LLMImpl::BenchModel(int& prompts, int& eval_prompts, int& n_max_sq, int& n_rep) +{ + auto prompts_avg = 0.0; + auto eval_prompts_avg = 0.0; + auto prompts_std = 0.0; + auto eval_prompts_std = 0.0; + + LOG_INF("m_nCtx = %d", this->m_nCtx); + + int i; + for (int nri = 0; nri < n_rep; nri++) { + LOG_INF("Benchmark prompt processing (pp)\n"); + + common_batch_clear(this->m_llmBatch); + + const int n_tokens = prompts; + for (i = 0; i < n_tokens; i++) { + common_batch_add(this->m_llmBatch, 0, i, {0}, false); + } + + this->m_llmBatch.logits[this->m_llmBatch.n_tokens - 1] = true; + llama_kv_self_clear(this->m_llmContext); + + const auto t_prompts_start = ggml_time_us(); + if (llama_decode(this->m_llmContext, this->m_llmBatch) != 0) { + LOG_INF("llama_decode() failed during prompt processing\n"); + } + const auto t_prompts_end = ggml_time_us(); + + // bench text generation + + LOG_INF("Benchmark text generation (tg)\n"); + + llama_kv_self_clear(this->m_llmContext); + const auto t_eval_prompts_start = ggml_time_us(); + for (i = 0; i < eval_prompts; i++) { + common_batch_clear(this->m_llmBatch); + for (int j = 0; j < n_max_sq; j++) { + common_batch_add(this->m_llmBatch, 0, i, {j}, true); + } + + LOG_INF("llama_decode() text generation: %d\n", i); + if (llama_decode(this->m_llmContext, this->m_llmBatch) != 0) { + LOG_INF("llama_decode() failed during text generation \n"); + } + } + + 
const auto t_eval_prompts_end = ggml_time_us(); + + llama_kv_self_clear(this->m_llmContext); + + const auto t_prompts = static_cast(t_prompts_end - t_prompts_start) / 1000000.0; + const auto t_eval_prompts = static_cast(t_eval_prompts_end - t_eval_prompts_start) / 1000000.0; + + const auto speed_prompts = static_cast(prompts) / t_prompts; + const auto speed_eval_prompts = static_cast(n_max_sq * eval_prompts) / t_eval_prompts; + + prompts_avg += speed_prompts; + eval_prompts_avg += speed_eval_prompts; + + prompts_std += speed_prompts * speed_prompts; + eval_prompts_std += speed_eval_prompts * speed_eval_prompts; + + LOG_INF("prompt eval %f t/s, token generation %f t/s\n", speed_prompts, speed_eval_prompts); + } + + prompts_avg /= static_cast(n_rep); + eval_prompts_avg /= static_cast(n_rep); + + if (n_rep > 1) { + prompts_std = sqrt(prompts_std / static_cast(n_rep - 1) - + prompts_avg * prompts_avg * static_cast(n_rep) / static_cast(n_rep - 1)); + eval_prompts_std = + sqrt(eval_prompts_std / static_cast(n_rep - 1) - + eval_prompts_avg * eval_prompts_avg * static_cast(n_rep) / static_cast(n_rep - 1)); + } else { + prompts_std = 0; + eval_prompts_std = 0; + } + + char model_desc[128]; + llama_model_desc(this->m_llmModel, model_desc, sizeof(model_desc)); + + const auto model_size = static_cast(llama_model_size(this->m_llmModel)) / 1024.0 / 1024.0 / 1024.0; + const auto model_n_params = static_cast(llama_model_n_params(this->m_llmModel)) / 1e9; + + const auto backend = "cpu"; // TODO: What should this be? + + std::stringstream result; + result << "| model | size | params | backend | test | t/s |\n"; + result << "| --- | --- | --- | --- | --- | --- |\n"; + result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " + << backend << " | prompts " << prompts << " | " << prompts_avg << " ± " << prompts_std + << " |\n"; + result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " + << backend << " | tg " << eval_prompts << " | " << eval_prompts_avg << " ± " + << eval_prompts_std << " |\n"; + + return result.str().c_str(); +} + +/** + * @brief Checks if a given string is valid UTF-8. + * + * This function validates whether the input C-string adheres to the UTF-8 encoding standard. + * It iterates through each byte of the string, determining the expected length of UTF-8 sequences + * based on leading byte patterns, and verifies that subsequent bytes match the UTF-8 format. + * + * @param string Pointer to a null-terminated C-string to be validated. + * @return true if the string is valid UTF-8 or if the input is a null pointer; false otherwise. 
+ */ +static bool is_valid_utf8(const char* string) +{ + if (!string) { + return true; + } + + auto bytes = reinterpret_cast(string); + int num; + + while (*bytes != 0x00) { + if ((*bytes & 0x80) == 0x00) { + // U+0000 to U+007F + num = 1; + } else if ((*bytes & 0xE0) == 0xC0) { + // U+0080 to U+07FF + num = 2; + } else if ((*bytes & 0xF0) == 0xE0) { + // U+0800 to U+FFFF + num = 3; + } else if ((*bytes & 0xF8) == 0xF0) { + // U+10000 to U+10FFFF + num = 4; + } else { + return false; + } + + bytes += 1; + for (int i = 1; i < num; ++i) { + if ((*bytes & 0xC0) != 0x80) { + return false; + } + bytes += 1; + } + } + return true; +} diff --git a/src/cpp/frameworks/llama_cpp/LlmImpl.hpp b/src/cpp/frameworks/llama_cpp/LlmImpl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1d3f027c4654e699e736f8840ec438101412c2aa --- /dev/null +++ b/src/cpp/frameworks/llama_cpp/LlmImpl.hpp @@ -0,0 +1,227 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef LLM_IMPL_HPP +#define LLM_IMPL_HPP + +#include "Llm.hpp" +#include "LlmConfig.hpp" + +#include "common.h" +#include "llama.h" + +#include +#include + +/* Forward declaration */ +class LLM; + +/** + * @brief LLama Implementation of our LLM API + */ +class LLM::LLMImpl { + +public: + LLMImpl(); + ~LLMImpl(); + + /** + * Method to Initialize a llama_model + * @param config Configuration class with model's parameter and user defined parameters + */ + void LlmInit(const LlmConfig& config); + + /** + * Method to Free all allocations pertaining to llama model + */ + void FreeLlm(); + + /** + * Function to retrieve the llama encode timings. + * @return The encoded tokens per second + */ + float GetEncodeTimings(); + + /** + * Function to retrieve the llama decode timings. + * @return The decoded tokens per second + */ + float GetDecodeTimings(); + + /** + * Function to reset the llama timing + */ + void ResetTimings(); + + /** + * Function to print the system info + * @return System info as a char pointer + */ + std::string SystemInfo(); + + /** + * Method to reset Conversation history and preserve Model's character prefix. + * If model's prefix is not defined all conversation history would be cleared + */ + void ResetContext(); + + /** + * Method to wrap CompletionInit function with batching and length fitness + * @param prompt Query to LLM + */ + void Encode(std::string& prompt); + + /** + * Method to wrap CompletionLoop function + * @return the next Token for Encoded Prompt + */ + std::string NextToken(); + + /** + * @brief The Method return the percentage of chat context filled + * @return chat capacity filled in cache as percentage number + */ + size_t GetChatProgress() const; + + /** + * @brief Benchmarks the performance of the LLM model. + * + * This function evaluates the model's performance by processing a specified number of prompts + * and generating text sequences. It measures the speed of prompt evaluation and text + * generation, calculates average speeds and standard deviations over multiple repetitions, and + * compiles the results into a formatted string. + * + * @param prompts Number of prompts to process during benchmarking. + * @param eval_prompts Number of evaluation prompts for text generation. + * @param n_max_sq Maximum sequence length for text generation. + * @param n_rep Number of repetitions for benchmarking to obtain average metrics. 
+ * @return A formatted string containing the benchmark results, including model description, + * size, number of parameters, backend information, and performance metrics for prompt + * evaluation and text generation. + */ + + std::string BenchModel(int& prompts, int& eval_prompts, int& n_max_sq, int& n_rep); + +private: + llama_context* m_llmContext{nullptr}; + llama_model* m_llmModel{nullptr}; + llama_batch m_llmBatch{}; + llama_sampler* m_pLlmSampler{nullptr}; + size_t m_batchSz{0}; + int m_nCtx{2048}; + std::string m_cachedTokenChars{""}; + size_t m_contextFilled{0}; + std::string m_llmPrefix{""}; + bool m_llmInitialized{false}; + size_t m_nCur{0}; + + /** + * Function to load the chosen llama model to memory + * @param pathToModel path to the model location + * @return llama_model or null-pointer if no model is found + */ + void LoadModel(const char* pathToModel); + + /** + * Function to create a new llama_context object in memory + * @param numThreads number of threads to set in the context + */ + void NewContext(int numThreads); + + /** + * Frees the memory holding the llama_model + */ + void FreeModel(); + + /** + * Free up the memory that is storing the llama_context + */ + void FreeContext(); + + /** + * Function to initialize the llama backend + */ + void BackendInit(); + + /** + * Function to free up the memory storing the backend + */ + void BackendFree(); + + /** + * Function to free up the memory storing the Batch instance + */ + void FreeBatch(); + + /** + * Function to free Sampler + */ + void FreeSampler(); + + /** + * Function to clear KV Cache and hence all conversation history + */ + void KVCacheClear(); + + /** + * Function to removes all tokens that belong to the last sequence(-1) and have positions in + * [p0, p1) + * @param p0 + * @param p1 + */ + void KVCacheSeqRm(int32_t p0, int p1); + + /** + * Function to tokenize the initial prompt + * @param text + * @param textLength + * @param addSpecial + * @param parseSpecial + * @return length of original prompt + */ + + int32_t GetInitialPromptLength(const char* text, + int32_t textLength, + bool addSpecial, + bool parseSpecial); + + /** + * Function to initialize batch object + * @param numTokens + * @param embeddings + * @param numSequenceMax + * @return batch object + */ + llama_batch NewBatch(int numTokens, int embeddings, int numSequenceMax); + + /** + * Function to Create a new sampler object + * @return Initialised sampler object + */ + + void NewSampler(); + + /**Taken from llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp and slightly + * modified + * @param sub_tokens_list a vector of tokens to encode into llama model + * @param lastBatch whether the current batch is last set of tokens in given query. + */ + void CompletionInit(llama_tokens sub_tokens_list, bool lastBatch); + + /** + * @brief Generates a token completion for the given context and batch. + * + * This function processes the current context and batch to generate the next token in the + *sequence. It utilizes the model's vocabulary and sampling methods to produce a token, which is + *then converted to a string representation. The function also handles end-of-sequence tokens + *and ensures UTF-8 validity of the generated token. + *. + * @return The generated token as a string. Returns "<|endoftext|>" if the end-of-sequence token + *is produced or if the current length reaches the maximum length. 
+ */ + std::string CompletionLoop(); +}; + +#endif /* LLM_IMPL_HPP */ diff --git a/src/cpp/include/LLM.hpp b/src/cpp/include/LLM.hpp deleted file mode 100644 index a96e793be86ff4c96ca7ea9ac29fa39b7ac56a9f..0000000000000000000000000000000000000000 --- a/src/cpp/include/LLM.hpp +++ /dev/null @@ -1,239 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// - -#ifndef LLM_LLM_HPP -#define LLM_LLM_HPP - -/** - * Interface to LLM - * Contains generic LLM functions - * Should invoke whatever implementation this is initialised with - */ - -template -class LLM -{ -private: - -public: - - /** - * Stores partial or accumulated token characters between inference steps - */ - std::string cached_token_chars; - - /** - * Function to load the chosen LLM model to memory - * @tparam P model type - * @param path_to_model path to the model location - * @return LLM model pointer - */ - template - P *LoadModel(const char *path_to_model) - { - return ((T *) this)->template LoadModel
<P>
(path_to_model); - } - - /** - * Frees the memory holding the LLM model - * @tparam P model type - * @param model LLM model pointer - */ - template - void FreeModel(P *model) - { - ((T *) this)->template FreeModel
<P>
(model); - } - - /** - * Function to create a new LLM context object in memory - * @tparam P LLM context type - * @tparam V LLM model type - * @param model LLM model instance - * @param numThreads number of threads - * @return context pointer - */ - template - P *NewContext(V *model, int numThreads) - { - return ((T *) this)->template NewContext(model, numThreads); - } - - /** - * Free up the memory that is storing the LLM context - * @tparam P LLM context type - * @param context LLM context pointer - */ - template - void FreeContext(P *context) - { - ((T *) this)->template FreeModel
<P>
(context); - } - - /** - * Function to initialize the LLM backend - */ - void BackendInit() - { - ((T *) this)->BackendInit(); - } - - /** - * Function to free up the memory storing the backend - */ - void BackendFree() - { - ((T *) this)->BackendFree(); - } - - /** - * Function to free up the memory storing the batch instance - * @tparam P LLM batch type - * @param batch LLM Batch object pointer - */ - template - void FreeBatch(P &batch) - { - ((T *) this)->template FreeBatch
<P>
(batch); - } - - /** - * Function to retrieve the LLM encode timings - * @tparam P LLM context type - * @param context LLM Context object pointer - * @return encode timings - */ - template - float GetEncodeTimings(P *context) - { - return ((T *) this)->template GetEncodeTimings
<P>
(context); - } - - /** - * Function to retrieve the LLM decode timings - * @tparam P LLM context type - * @param context LLM Context object pointer - * @return decode timings - */ - template - float GetDecodeTimings(P *context) - { - return ((T *) this)->template GetDecodeTimings
<P>
(context); - } - - /** - * Function to reset the LLM timings - * @tparam P LLM context type - * @param context LLM context object pointer - */ - template - void ResetTimings(P *context) - { - ((T *) this)->template ResetTimings
<P>
(context); - } - - /** - * Function to print the system info - * @return system information - */ - const char *SystemInfo() - { - return ((T *) this)->SystemInfo(); - } - - /** - * Function to perform KV Cache clear - * @tparam P LLM context type - * @param context LLM Context object pointer - */ - template - void KVCacheClear(P *context) - { - ((T *) this)->template KVCacheClear
<P>
(context); - } - - /** - * Function to removes all tokens that belong to the specified sequence and have positions in [p0, p1) - * @tparam P LLM context object type - * @param context LLM Context object - * @param p0 starting token index (inclusive) - * @param p1 ending token index (exclusive) - */ - template - void KVCacheSeqRm(P *context, int p0, int p1) - { - ((T *) this)->template KVCacheSeqRm
<P>
(context, p0, p1); - } - - /** - * Function to tokenize the initial prompt - * @tparam P LLM model type - * @tparam V LLM tokens container - * @param model pointer to the LLM model - * @param text prompt text to be tokenized - * @param textLength length of the prompt text in bytes (or characters) - * @param tokens pointer to the container/array that will hold the resulting tokens - * @param maxNumTokens maximum number of tokens - * @param addSpecial if `true`, includes special tokens (e.g., BOS/EOS) in the output - * @param parseSpecial if `true`, parses special tokens directly from the prompt text - * @return length of original prompt - */ - template - int GetInitialPromptLength(P *model, const char *text, int32_t textLength, V *tokens, - int32_t maxNumTokens, bool addSpecial, bool parseSpecial) - { - return ((T *) this)->template GetInitialPromptLength(model, text, textLength, - tokens, maxNumTokens, addSpecial, parseSpecial); - } - - /** - * Function to create a new batch object - * @tparam P LLM batch type - * @param embeddings embedding dimension for each token - * @param numTokens maximum number of tokens in the batch - * @param numSequenceMax maximum number of sequences or contexts - * @return newly created batch object - */ - template - P NewBatch(int embeddings, int numTokens, int numSequenceMax) - { - return ((T *) this)->template NewBatch
<P>
(embeddings, numTokens, numSequenceMax); - } - - /** - * Function used for encoding the prompt text - * @tparam P LLM Context Type - * @tparam V LLM Batch Type - * @param text input prompt text to be encoded - * @param context pointer to the LLM context - * @param batch pointer to the LLM batch - * @param startPos start position of the text that will be used - * @return number of tokens if successful otherwise error code - */ - template - int CompletionInit(std::string text, P *context, V *batch, int startPos) - { - return ((T *) this)->template CompletionInit(text, context, batch, startPos); - } - - /** - * Main inference loop, returns each token of response from llm - * @tparam P LLM context type - * @tparam V LLM batch type - * @param context LLM context object pointer - * @param batch LLM batch object pointer - * @param nCur reference to the current token index in the sequence - * @param nLen reference to the total number of tokens to generate - * @return newly generated token as a string - */ - template - std::string CompletionLoop(P *context, V *batch, int &nCur, int &nLen) - { - return ((T *) this)->template CompletionLoop(context, batch, nCur, nLen); - } -}; - -#endif //LLM_LLM_HPP diff --git a/src/cpp/interface/CMakeLists.txt b/src/cpp/interface/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..bfc8c1ee41b0fe921fd175702a427bc4833d5834 --- /dev/null +++ b/src/cpp/interface/CMakeLists.txt @@ -0,0 +1,14 @@ +# +# SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +# Declare project +project(arm-llm-interface + DESCRIPTION "An LLM CPP interface") + +# Add the LLM API interface library: +add_library(arm-llm-interface INTERFACE) +target_include_directories(arm-llm-interface INTERFACE + ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/src/cpp/interface/Llm.hpp b/src/cpp/interface/Llm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..03cd4f4c96c18b047946eb41ab0669c3cb491227 --- /dev/null +++ b/src/cpp/interface/Llm.hpp @@ -0,0 +1,92 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#ifndef ARM_LLM_HPP +#define ARM_LLM_HPP + +#include "LlmConfig.hpp" +#include + +class LLM { +private: + class LLMImpl; + std::unique_ptr m_impl{}; + +public: + LLM(); /**< Constructor */ + ~LLM(); /**< Destructor */ + + /** + * Method to Initialize a llama_model + * @param llmConfig Configuration class with model's parameter and user defined parameters + */ + void LlmInit(const LlmConfig& llmConfig); + + /** + * Method to Free Model + */ + void FreeLlm(); + + /** + * Function to retrieve the llm encode timings + * @return encode timings + */ + float GetEncodeTimings(); + + /** + * Function to retrieve the llm decode timings + * @return decode timings + */ + float GetDecodeTimings(); + + /** + * Function to reset the llm timings + */ + void ResetTimings(); + + /** + * Function to print the system info + * @return System info as a char pointer + */ + std::string SystemInfo(); + + /** + * Method to reset Conversation history and preserve Model's character prefix. + * If model's prefix is not defined all conversation history would be cleared + */ + void ResetContext(); + + /** + * Function to Encode Query into the llm. Use NextToken to get subsequent tokens. + * @param text THe query to be encoded + */ + void Encode(std::string text); + + /** + * Function to get response from llm as token by token. 
Call Encode before + * @return result single token + */ + std::string NextToken(); + + /** + * Function to get percentage of Context capacity filled in model's cache + * @return percentage of context filled + */ + size_t GetChatProgress(); + + /** + * Function to bench the underlying llm backend + * @param nPrompts prompt length used for benchmarking + * @param nEvalPrompts number of generated tokens for benchmarking + * @param nMaxSeq sequence number + * @param nRep number of repetitions + * @return the results of benchmarking in string format for prompt generation and evaluation + */ + + std::string BenchModel(int& nPrompts, int& nEvalPrompts, int& nMaxSeq, int& nRep); +}; + +#endif /* ARM_LLM_HPP */ diff --git a/src/cpp/llama_cpp/CMakeLists.txt b/src/cpp/llama_cpp/CMakeLists.txt deleted file mode 100644 index 918be765c02e72b860938ddc213a05ea119c0f19..0000000000000000000000000000000000000000 --- a/src/cpp/llama_cpp/CMakeLists.txt +++ /dev/null @@ -1,125 +0,0 @@ -# -# SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates -# -# SPDX-License-Identifier: Apache-2.0 -# - -# Declare project -project(llama-cpp-wrapper - VERSION 0.0.1 - DESCRIPTION "llama.cpp wrapper interface implementation" - LANGUAGES C CXX ASM) - -include(FetchContent) - -# Where should llama.cpp sources be cloned into? -# It might make sense to download sources into resources as well and not -# every time into the CMake binary directory. However, because we currently -# need to patch it, and possibly do it conditionally based on target type -# we have this arrangement for the time being. -set(LLAMA_SRC_DIR "${CMAKE_BINARY_DIR}/llama.cpp" - CACHE PATH - "Path where llama.cpp repo should be cloned into") - -set(LLAMA_GIT_URL "https://github.com/ggerganov/llama.cpp.git" - CACHE STRING - "Git URL for llama.cpp repo") - -set(LLAMA_GIT_SHA "a4090d1" - CACHE STRING - "Git commit SHA for llama.cpp repo") - -set(LLAMA_BUILD_EXAMPLES ${BUILD_EXECUTABLE} CACHE BOOL "Build llama.cpp examples") - -# Identify system processor for Arm -string(TOLOWER "${CMAKE_SYSTEM_PROCESSOR}" lower_system_processor) -if (lower_system_processor MATCHES "^(aarch64|arm64)$") - set(ARM_SYSTEM_PROCESSOR ON BOOL) -endif() - -# If the user has NOT explicitly set GGML_CPU_KLEIDIAI -if (NOT DEFINED GGML_CPU_KLEIDIAI) - # if we are on arm64/aarch64, then default KleidiAI to ON. - if (ARM_SYSTEM_PROCESSOR) - set(GGML_CPU_KLEIDIAI ON CACHE BOOL - "Enable KleidiAI by default on ${CMAKE_SYSTEM_PROCESSOR}") - message(STATUS "KleidiAI enabled by default") - # if we are NOT on arm64/aarch64, then default KleidiAI to OFF. - else() - set(GGML_CPU_KLEIDIAI OFF CACHE BOOL - "Disable KleidiAI by default on ${CMAKE_SYSTEM_PROCESSOR}") - message(STATUS "KleidiAI disabled by default") - endif() -else () - message(STATUS "KleidiAI: ${GGML_CPU_KLEIDIAI}") -endif() - -# Avoid enabling KleidiAI for native builds -if(GGML_CPU_KLEIDIAI AND NOT ARM_SYSTEM_PROCESSOR) - message(FATAL_ERROR "KleidiAI enabled, but not supported on ${CMAKE_SYSTEM_PROCESSOR}") -endif() - -# ----------------------------------------- -# NOTE: Current work flow for i8mm implementation: -# If user provides any -march flag, check if compiler supports it. -# ----------------------------------------- - -# Include functions for flag checking -include(check-flag) - -# Define the list of supported languages. -set(SUPPORTED_LANGUAGES C CXX) - -# Check and set C,CXX flags -foreach(LANG ${SUPPORTED_LANGUAGES}) - # Check for -march flag in C flags. 
- string(REGEX MATCH "-march=[^ ]+" USER_${LANG}_MARCH_FLAG "${CMAKE_${LANG}_FLAGS}") - if (USER_${LANG}_MARCH_FLAG) - message(STATUS "User provided -march flag in ${LANG} flags: ${USER_${LANG}_MARCH_FLAG}") - # Check if the compiler supports the provided C,CXX -march flag - check_compiler_support("${LANG}" "${USER_${LANG}_MARCH_FLAG}") - endif() -endforeach() - -# Fetch the dependency Git repo here -FetchContent_Declare(llama-cpp - GIT_REPOSITORY ${LLAMA_GIT_URL} - GIT_TAG ${LLAMA_GIT_SHA} - GIT_SHALLOW 1 # We only need shallow clone with `--depth=1` - SOURCE_DIR ${LLAMA_SRC_DIR} - ) - -FetchContent_MakeAvailable(llama-cpp) - -if (NOT TARGET arm-llm-cpp) - add_library(arm-llm-cpp INTERFACE) - target_include_directories(arm-llm-cpp INTERFACE - ${CMAKE_CURRENT_SOURCE_DIR}/../include) -endif () - -# Add the current -target_include_directories(arm-llm-cpp INTERFACE - ${CMAKE_CURRENT_SOURCE_DIR}/include) - -# Do not turn incompatible pointer types warning into an error -# Temporary fix for ggml-cpu target for the revision of llama.cpp used currently -target_compile_options(ggml-cpu PRIVATE -Wno-error=incompatible-pointer-types) - -# List all libraries that we need to depend on here: -target_link_libraries(arm-llm-cpp INTERFACE - llama - common -) - -if (BUILD_JNI_LIB) - # Make sure JNI include directories have been set. - include(find-jni) - - add_library(arm-llm-jni SHARED - ${CMAKE_CURRENT_SOURCE_DIR}/jni/Llama.cpp) - - target_link_libraries(arm-llm-jni PUBLIC arm-llm-cpp) - target_include_directories(arm-llm-jni PUBLIC - ${JNI_INCLUDE_DIRS} # Populated by FindJNI CMake module - ) -endif () diff --git a/src/cpp/llama_cpp/include/LlamaImpl.hpp b/src/cpp/llama_cpp/include/LlamaImpl.hpp deleted file mode 100644 index 0982c1ddc6b3b999e24859dea318ec0a89d59a3a..0000000000000000000000000000000000000000 --- a/src/cpp/llama_cpp/include/LlamaImpl.hpp +++ /dev/null @@ -1,405 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// - -# pragma once - -#include -#include -#include "llama.h" -#include "common.h" -#include "LLM.hpp" - -#define LOG_INF(...) 
do { fprintf(stdout, __VA_ARGS__); } while (0) - -/** -* @brief LLama Implementation of our LLM API -* -*/ -class LlamaImpl : public LLM -{ -private: - -public: - LlamaImpl() = default; - - /** - * Function to load the chosen llama model to memory - * @tparam P llama_model - * @param path_to_model path to the model location - * @return llama_model or null pointer if no model is found - */ - template - P *LoadModel(const char *path_to_model) - { - const llama_model_params model_params = llama_model_default_params(); - auto model = llama_model_load_from_file(path_to_model, model_params); - if (model == nullptr) { - fprintf(stderr , "%s: error: unable to load model\n" , __func__); - return model; - } - - return model; - } - - /** - * Free the memory holding the llama_model - * @tparam P llama_model - * @param model the pointer to the llama_model - */ - template - void FreeModel(P *model) - { - llama_model_free(model); - } - - /** - * Function to create a new llama_context object in memory - * @tparam P llama_context - * @tparam V llama_model - * @param model LLM model pointer - * @param numThreads number of threads to set in the context - * @return LLM context object pointer - */ - template - P *NewContext(V *model, const int numThreads) - { - llama_context_params ctx_params = llama_context_default_params(); - ctx_params.n_ctx = 2048; - ctx_params.n_threads = numThreads; - ctx_params.n_threads_batch = numThreads; - ctx_params.no_perf = false; - - llama_context *context = llama_init_from_model(model, ctx_params); - - return context; - } - - /** - * Free up the memory that is storing the llama_context - * @tparam P llama_context - * @param llamaContext LLM context pointer - */ - template - void FreeContext(P *llamaContext) - { - llama_free(llamaContext); - } - - /** - * Function to initialize the llama backend - */ - void BackendInit() - { - llama_backend_init(); - } - - /** - * Function to free up the memory storing the backend - */ - void BackendFree() - { - llama_backend_free(); - } - - /** - * Function to free up the memory storing the Batch instance - * @tparam P llama_batch - * @param batch LLM Batch object pointer - */ - template - void FreeBatch(P &batch) - { - llama_batch_free(batch); - } - - /** - * Function to retrieve the llama encode timings - * @tparam P llama_context - * @param context LLM Context object pointer - * @return The encoded tokens per second - */ - template - float GetEncodeTimings(P *context) - { - auto resultsTiming = llama_perf_context(context); - return (1e3 / resultsTiming.t_p_eval_ms * resultsTiming.n_p_eval); - } - - /** - * Function to retrieve the llama decode timings - * @tparam P llama_context - * @param context LLM Context object pointer - * @return The decoded tokens per second - */ - template - float GetDecodeTimings(P *context) - { - auto resultsTiming = llama_perf_context(context); - return (1e3 / resultsTiming.t_eval_ms * resultsTiming.n_eval); - } - /** - * Function to reset the llama timings - * @tparam P llama_context - * @param context LLM context object pointer - */ - template - void ResetTimings(P *context) - { - llama_perf_context_reset(context); - } - - /** - * Function to print the system info - * @return System info as a char pointer - */ - const char *SystemInfo() - { - return llama_print_system_info(); - } - - /** - * Function to clear KV Cache - * @tparam P llama_context - * @param context LLM Context object pointer - */ - template - void KVCacheClear(P *context) - { - llama_kv_cache_clear(context); - } - - /** - * Function to remove all 
tokens that belong to the specified sequence and have positions in [p0, p1) - * @tparam P llama_context - * @param context LLM Context object - * @param p0 starting token index (inclusive) - * @param p1 ending token index (exclusive) - */ - template - void KVCacheSeqRm(P *context, int p0, int p1) - { - llama_kv_cache_seq_rm(context, -1, p0, p1); - } - - /** - * Function to tokenize the initial prompt - * @tparam P llama_model - * @tparam V llama tokens container - * @param model pointer to the LLM model - * @param text prompt text to be tokenized - * @param textLength length of the prompt text in bytes (or characters) - * @param tokens pointer to the container/array that will hold the resulting tokens - * @param maxNumTokens maximum number of tokens - * @param addSpecial if `true`, includes special tokens (e.g., BOS/EOS) in the output - * @param parseSpecial if `true`, parses special tokens directly from the prompt text - * @return length of original prompt - */ - template - int GetInitialPromptLength(P *model, const char *text, int32_t textLength, V *tokens, - int32_t maxNumTokens, bool addSpecial, bool parseSpecial) - { - const llama_vocab * vocab = llama_model_get_vocab(model); - return llama_tokenize(vocab, - text, - textLength, - tokens, - maxNumTokens, - addSpecial, - parseSpecial); - } - - /** - * Function to Create a new batch object - * @tparam P llama_batch - * @param embeddings embedding dimension for each token - * @param numTokens maximum number of tokens in the batch - * @param numSequenceMax maximum number of sequences or contexts - * @return newly created batch object - */ - template - P NewBatch(int numTokens, int embeddings, int numSequenceMax) - { - llama_batch batch = llama_batch_init(numTokens, embeddings, numSequenceMax); - return batch; - } - - /** - * Function to Create a new sampler object - * @param p_llama_sampler - * @return Initialised sampler object - */ - - llama_sampler *NewSampler(llama_sampler *p_llama_sampler) - { - auto sampler_params = llama_sampler_chain_default_params(); - sampler_params.no_perf = false; - - p_llama_sampler = llama_sampler_chain_init(sampler_params); - llama_sampler_chain_add(p_llama_sampler, llama_sampler_init_greedy()); - return p_llama_sampler; - } - - /** Encode the prompt text, inspired by llama.cpp Android example - * Use 0 for startPos else nCur - * @tparam P llama context - * @tparam V llama batch - * @param text input prompt text to be encoded - * @param context pointer to the LLM context - * @param batch pointer to the LLM batch - * @param startPos start position of the text that will be used - * @return number of tokens if successful otherwise error code - */ - template - int CompletionInit(std::string &text, P *context, V *batch, int &startPos) - { - //Synchronize llama to remove idle time between function calls - llama_synchronize(context); - - const auto tokens_list = common_tokenize(context, text, 1); - common_batch_clear(*batch); - // evaluate the initial prompt - for (auto i = startPos; i < tokens_list.size() + startPos; i++) - { - common_batch_add(*batch, tokens_list[i - startPos], i, {0}, false); - } - - // llama_decode will output logits only for the last token of the prompt - batch->logits[batch->n_tokens - 1] = true; - if (llama_decode(context, *batch) != 0) - { - LOG_INF("llama_decode() failed"); - return 1; - } - llama_synchronize(context); - return batch->n_tokens; - } - - /** - * @brief Generates a token completion for the given context and batch - * - * This function processes the current context and batch 
to generate the next token in the sequence. - * It utilizes the model's vocabulary and sampling methods to produce a token, which is then converted - * to a string representation. The function also handles end-of-sequence tokens and ensures UTF-8 validity - * of the generated token - * - * @tparam P Type representing the context, typically `llama_context` - * @tparam V Type representing the batch, typically `llama_batch` - * @param context pointer to the LLM context object - * @param batch pointer to the LLM batch object - * @param nCur reference to the current length of the sequence - * @param nLen reference to the maximum length of the sequence - * @return generated token as a string. Returns "<|endoftext|>" if the end-of-sequence token is produced - * or if the current length reaches the maximum length - */ - template - std::string CompletionLoop(P *context, V *batch, int &nCur, int &nLen) - { - const auto model = llama_get_model(context); - std::string test; - - const llama_vocab * vocab = llama_model_get_vocab(model); - auto n_vocab = llama_vocab_n_tokens(vocab); - - auto logits = llama_get_logits_ith(context, batch->n_tokens - 1); - - std::vector candidates; - candidates.reserve(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) - { - candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); - } - - llama_sampler *sampler_pointer = NewSampler(sampler_pointer); - - const auto new_token_id = llama_sampler_sample(sampler_pointer, context, -1); - - if ((llama_vocab_eos(vocab) == new_token_id) || (nCur == nLen)) - { - return "<|endoftext|>"; - } - - auto new_token_chars = common_token_to_piece(context, new_token_id); - cached_token_chars += new_token_chars; - std::string new_token; - if (is_valid_utf8(cached_token_chars.c_str())) - { - new_token = cached_token_chars; - cached_token_chars.clear(); - } else - { - new_token = ""; - } - common_batch_clear(*batch); - common_batch_add(*batch, new_token_id, nCur, {0}, true); - - if (llama_decode(context, *batch) != 0) - { - LOG_INF("llama_decode() failed"); - } - - //Synchronize llama to remove idle time between function calls - llama_synchronize(context); - - ++nCur; - return new_token; - } - - /** - * @brief Checks if a given string is valid UTF-8 - * - * This function validates whether the input C-string adheres to the UTF-8 encoding standard. 
- * It iterates through each byte of the string, determining the expected length of UTF-8 sequences - * based on leading byte patterns, and verifies that subsequent bytes match the UTF-8 format - * - * @param string Pointer to a null-terminated C-string to be validated - * @return true if the string is valid UTF-8 or if the input is a null pointer; false otherwise - */ - bool is_valid_utf8(const char *string) - { - if (!string) - { - return true; - } - - const auto *bytes = reinterpret_cast(string); - int num; - - while (*bytes != 0x00) - { - if ((*bytes & 0x80) == 0x00) - { - // U+0000 to U+007F - num = 1; - } else if ((*bytes & 0xE0) == 0xC0) - { - // U+0080 to U+07FF - num = 2; - } else if ((*bytes & 0xF0) == 0xE0) - { - // U+0800 to U+FFFF - num = 3; - } else if ((*bytes & 0xF8) == 0xF0) - { - // U+10000 to U+10FFFF - num = 4; - } else - { - return false; - } - - bytes += 1; - for (int i = 1; i < num; ++i) - { - if ((*bytes & 0xC0) != 0x80) - { - return false; - } - bytes += 1; - } - } - return true; - } -}; diff --git a/src/cpp/llama_cpp/jni/Llama.cpp b/src/cpp/llama_cpp/jni/Llama.cpp deleted file mode 100644 index 6f8c98071c067adbb45caa2fca00d3b0259f6a53..0000000000000000000000000000000000000000 --- a/src/cpp/llama_cpp/jni/Llama.cpp +++ /dev/null @@ -1,286 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// - -#include -#include -#include "LlamaImpl.hpp" -#include "LLM.hpp" -#include -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -/** - * Load llama model from path - * @param env JNI environment - * @param path_to_model path to llama model - * @return pointer to llama model - */ -JNIEXPORT jlong JNICALL Java_com_arm_llm_Llama_loadModel(JNIEnv *env, jobject, jstring path_to_model) -{ - - const char *path = env->GetStringUTFChars(path_to_model, nullptr); - if (path == nullptr || strlen(path) == 0) - { - env->ReleaseStringUTFChars(path_to_model, path); - return 0; - } - - auto *llm = new LLM(); - auto *model = llm->LoadModel(path); - env->ReleaseStringUTFChars(path_to_model, path); - return reinterpret_cast(model); -} - -/** - * Perform KV cache clear - * @param contextPtr pointer to the model context - */ -JNIEXPORT void JNICALL -Java_com_arm_llm_Llama_kvCacheClear(JNIEnv, jobject, jlong contextPtr) -{ - llama_kv_self_clear(reinterpret_cast(contextPtr)); -} - -/** - * Remove all tokens that belong to the specified sequence - * @param contextPtr pointer to the model context - * @param start_pos starting position of sequence - * @param last_pos last position of sequence - */ -JNIEXPORT void JNICALL -Java_com_arm_llm_Llama_kvCacheSeqRm(JNIEnv, jobject, jlong contextPtr, jint start_pos, jint last_pos) -{ - llama_kv_self_seq_rm(reinterpret_cast(contextPtr), -1, start_pos, last_pos); -} - -/** - * Computes the token length of the given prompt text using - * @param env JNI environment - * @param model_ptr pointer to llama model - * @param text_length length of the prompt text - * @param jtext string containing the prompt text to be tokenized - * @param add_special if true, includes special tokens - * @return the number of tokens generated from the text - */ -JNIEXPORT jint JNICALL -Java_com_arm_llm_Llama_getInitialPromptLength(JNIEnv *env, jobject, jlong model_ptr, jint text_length, jstring jtext, - jboolean add_special) -{ - auto *model = reinterpret_cast(model_ptr); - - const auto text = env->GetStringUTFChars(jtext, nullptr); - bool parse_special = false; - int max_num_tokens = 
1024; - auto tokens = static_cast(malloc(sizeof(llama_token) * max_num_tokens)); - const llama_vocab * vocab = llama_model_get_vocab(model); - return llama_tokenize(vocab,text,text_length,tokens,max_num_tokens,add_special,parse_special); - -} - -/** - * Frees the model, associated context, and llama backend resources - * @param model pointer to llama model - * @param contextPtr pointer to the model context - */ -JNIEXPORT void JNICALL -Java_com_arm_llm_Llama_freeModel(JNIEnv *, jobject, jlong model, jlong contextPtr) -{ - llama_free(reinterpret_cast(contextPtr)); - llama_model_free(reinterpret_cast(model)); - llama_backend_free(); -} - -/** - * Initialize llama backend - */ -JNIEXPORT void JNICALL Java_com_arm_llm_Llama_backendInit(JNIEnv, jobject) -{ - auto *llm = new LLM(); - llm->BackendInit(); -} - -/** - * Creates and initializes a new sampler chain with default parameters - * @return pointer to the newly created sampler chain - */ -JNIEXPORT jlong JNICALL -Java_com_arm_llm_Llama_newSampler(JNIEnv *, jobject) -{ - auto sampler_params = llama_sampler_chain_default_params(); - llama_sampler *smpl = llama_sampler_chain_init(sampler_params); - llama_sampler_chain_add(smpl, llama_sampler_init_greedy()); - llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.0)); - return reinterpret_cast(smpl); -} - -/** - * Frees a sampler previously created new sampler - */ -JNIEXPORT void JNICALL -Java_com_arm_llm_Llama_freeSampler(JNIEnv *, jobject, jlong sampler_pointer) -{ - llama_sampler_free(reinterpret_cast(sampler_pointer)); -} - -/** - * Creates a new llama batch - * @param n_tokens total number of tokens - * @param embd embedding dimension - * @param n_seq_max The maximum number of sequences - * @return pointer to the newly allocated llama batch - */ -JNIEXPORT jlong JNICALL Java_com_arm_llm_Llama_newBatch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) -{ - - auto *batch = new llama_batch{ - 0, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - }; - - if (embd) - { - batch->embd = static_cast(malloc(sizeof(float) * n_tokens * embd)); - } else - { - batch->token = static_cast(malloc(sizeof(llama_token) * n_tokens)); - } - - batch->pos = static_cast(malloc(sizeof(llama_pos) * n_tokens)); - batch->n_seq_id = static_cast(malloc(sizeof(int32_t) * n_tokens)); - batch->seq_id = static_cast(malloc(sizeof(llama_seq_id *) * n_tokens)); - for (int i = 0; i < n_tokens; ++i) - { - batch->seq_id[i] = static_cast(malloc(sizeof(llama_seq_id) * n_seq_max)); - } - batch->logits = static_cast(malloc(sizeof(int8_t) * n_tokens)); - - return reinterpret_cast(batch); -} - -/** - * Create a new context - * @param env JNI environment - * @param model_ptr llama model pointer - * @param numThreads number of threads - * @return pointer to the new model context - */ -JNIEXPORT jlong JNICALL Java_com_arm_llm_Llama_newContext(JNIEnv *env, jobject, jlong model_ptr, jint numThreads) -{ - - auto model = reinterpret_cast(model_ptr); - if (!model) - { - env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null"); - return 0; - } - auto *llm = new LLM(); - auto *context = llm->NewContext(model, numThreads); - - if (!context) - { - env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), - "llama_new_context_with_model() returned null)"); - return 0; - } - - return reinterpret_cast(context); -} - -/** - * Function used for encoding the prompt text - * @param env JNI environment - * @param jtext text containing the prompt to encode - * @param contextPtr 
pointer to model context - * @param batch_pointer pointer to current batch - * @param start_pos starting position in the text to be encoded - * @return number of tokens - */ -JNIEXPORT jint JNICALL -Java_com_arm_llm_Llama_completionInit(JNIEnv *env, jobject, jstring jtext, jlong contextPtr, jlong batch_pointer, \ -jint start_pos) -{ - const auto text = env->GetStringUTFChars(jtext, nullptr); - const auto context = reinterpret_cast(contextPtr); - const auto batch = reinterpret_cast(batch_pointer); - auto *llm = new LLM(); - int n_tokens = llm->CompletionInit(text, context, batch, start_pos); - env->ReleaseStringUTFChars(jtext, text); - - return n_tokens; -} - -/** - * This function processes the current context and batch to generate the next token in the sequence. - * It utilizes the model's vocabulary and sampling methods to produce a token, which is then converted - * to a string representation. The function also handles end-of-sequence tokens and ensures UTF-8 validity - * of the generated token - * - * @param env JNI environment - * @param contextPtr pointer to model context - * @param batchPtr pointer to current batch - * @param nCur current sequence length - * @param nLen max sequence length - * @return the next token - */ -JNIEXPORT jstring JNICALL -Java_com_arm_llm_Llama_completionLoop(JNIEnv *env, jobject, jlong contextPtr, jlong batchPtr, jint nCur, jint nLen) -{ - auto *context = reinterpret_cast(contextPtr); - auto *batch = reinterpret_cast(batchPtr); - auto *llm = new LLM(); - std::string result = llm->CompletionLoop(context, batch, nCur, nLen); - return env->NewStringUTF(result.c_str()); -} - -/** - * Get llama encode timings - * @param contextPtr pointer to the model context - * @return encode timings - */ -JNIEXPORT jfloat JNICALL -Java_com_arm_llm_Llama_getEncodeTimings(JNIEnv, jobject, jlong contextPtr) -{ - auto *context = reinterpret_cast(contextPtr); - auto *llm = new LLM(); - float result = llm->GetEncodeTimings(context); - return result; -} - -/** - * Get llama decode timings - * @param contextPtr pointer to the model context - * @return decode timings - */ -JNIEXPORT jfloat JNICALL -Java_com_arm_llm_Llama_getDecodeTimings(JNIEnv, jobject, jlong contextPtr) -{ - - auto resultsTiming = llama_perf_context(reinterpret_cast(contextPtr)); - return static_cast(1e3 / resultsTiming.t_eval_ms * resultsTiming.n_eval); -} - -/** - * Reset timings recorded previously - * @param contextPtr pointer to the model context - */ -JNIEXPORT void JNICALL -Java_com_arm_llm_Llama_resetTimings(JNIEnv, jobject, jlong contextPtr) -{ - llama_perf_context_reset(reinterpret_cast(contextPtr)); -} - -#ifdef __cplusplus -} -#endif diff --git a/src/java/CMakeLists.txt b/src/java/CMakeLists.txt index 6c1f2a389f46e03b56e95d68779998ddd5861c84..500d61e4ed911df8033b234363d57d597910658c 100644 --- a/src/java/CMakeLists.txt +++ b/src/java/CMakeLists.txt @@ -13,15 +13,8 @@ project(arm-llm-java-prj # Add the Java LLM API interface library. This is never going to be built # but we can use its property to get the interface source used in other # targets. 
+ add_library(arm-llm-java INTERFACE) -if (${LLM_DEP_NAME} STREQUAL "llama.cpp") - target_sources(arm-llm-java INTERFACE com/arm/llm/Llama.java com/arm/llm/LlamaConfig.java) - add_dependencies(arm-llm-java arm-llm-jni) -#elseif(${LLM_DEP_NAME} STREQUAL "executorch") -# # TODO -#elseif(${LLM_DEP_NAME} STREQUAL "mediapipe") -# # TODO -else() - message(FATAL_ERROR "${LLM_DEP_NAME} is currently not supported :(") -endif() +target_sources(arm-llm-java INTERFACE com/arm/Llm.java com/arm/LlmConfig.java) +add_dependencies(arm-llm-java arm-llm-jni) diff --git a/src/java/com/arm/Llm.java b/src/java/com/arm/Llm.java new file mode 100644 index 0000000000000000000000000000000000000000..889f8794cef3c6bd5e65dab559b60f7376aab639 --- /dev/null +++ b/src/java/com/arm/Llm.java @@ -0,0 +1,269 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +package com.arm; + +import java.util.List; +import java.util.concurrent.Flow; +import java.util.concurrent.SubmissionPublisher; +import java.util.concurrent.atomic.AtomicBoolean; + +public class Llm extends SubmissionPublisher +{ + static + { + try + { + System.loadLibrary("arm-llm-jni"); + } catch (UnsatisfiedLinkError e) + { + System.err.println("Llama: Failed to load library: arm-llm-jni"); + e.printStackTrace(); + } + } + + private long llmPtr = 0; + private String modelTag = ""; + private String userTag = ""; + private List stopWords = null; + private String cachedToken = ""; + private String emitToken = ""; + private String llmPrefix = ""; + private AtomicBoolean evaluatedOnce = new AtomicBoolean(false); + //#ToDo move to LlmConfig + private int numThreads = 4; + private int batchSize = 256; + + // Native method declarations + /** + Method to create LlmConfig cpp instance from params. + @param modelTag name used to refer the Model + @param modelPath path to load model from + @param llmPrefix Initial-prompt to load into llm before query + @param numThreads Number of threads for inference + @param batchSize batch size used to chunk queries + */ + public native long createLlmConfig(String modelTag, String modelPath, String llmPrefix, + int numThreads, int batchSize); + /** + * Method for loading LLM model + @param pathToModel file path for loading model + @return pointer to loaded model + */ + public native long loadModel(long LlmConfig); + + /** + * Method for freeing LLM model + * @param modelPtr to free model + */ + private native void freeLlm(); + + /** + * Public method for getting encode timing + * @return timings in tokens/s for encoding prompt + */ + public native float getEncodeRate(); + + /** + * Public method for getting decode timing + * @return timings in tokens/s for decoding prompt + */ + public native float getDecodeRate(); + + /** + * Private method for resetting conversation history + */ + public native void resetContext(); + + /** + * Method for resetting timing information + */ + public native void resetTimings(); + + /** + * Method to encode the given text + * @param text the prompt to be encoded + */ + private native void encode(String text); + + /** + * Method to get Next Token once encoding is done. + * This Method needs to be called in a loop while monitoring for Stop-Words. 
+ * @return next Token as String + */ + private native String getNextToken(); + + /** + * Method to get chat Progress in percentage + */ + public native int getChatProgress(); + + /** + * Method to decode answers one by one, once prefill stage is completed + * + * @param nPrompts prompt length used for benchmarking + * @param nEvalPrompts number of generated tokens for benchmarking + * @param nMaxSeq sequence number + * @param nRep number of repetitions + */ + public native String benchModel( + int nPrompts, + int nEvalPrompts, + int nMaxSeq, + int nRep + ); + + /** + *Method to separate Initialization from constructor + *@param llmConfig type configuration file to load model + */ + public void llmInit(LlmConfig llmConfig) + { + this.stopWords = llmConfig.getStopWords(); + this.modelTag = llmConfig.getModelTag(); + this.userTag = llmConfig.getUserTag(); + this.llmPrefix = llmConfig.getLlmPrefix(); + this.numThreads = llmConfig.getNumThreads(); + long configPtr = createLlmConfig(this.modelTag,llmConfig.getModelPath(), + this.llmPrefix,this.numThreads,this.batchSize); + this.llmPtr = loadModel(configPtr); + } + public void setSubscriber(Flow.Subscriber subscriber) + { + System.out.println("subscribed set from llama"); + this.subscribe(subscriber); + } + + /** + * Method to get response of a query asynchronously + * @param Query the prompt asked + */ + public void sendAsync(String Query) + { + String query = ""; + AtomicBoolean stop = new AtomicBoolean(false); + if (evaluatedOnce.get()) + query = userTag + Query + modelTag; + else + query = llmPrefix + Query + modelTag; + encode(query); + evaluatedOnce.set(true); + while (getChatProgress()<100) + { + String s = getNextToken(); + stop.set(inspectWord(s)); + if (stop.get()) + { + // needed for showing end of stream, Closing publisher will result in error + // for next query + emitToken = ""; + this.submit(emitToken); + + break; + } + this.submit(emitToken); + } + } + + /** + * Method to get response of a query synchronously + * @param Query the prompt asked + * @return response of LLM + */ + public String send(String Query) + { + String response = ""; + String query = ""; + boolean stop = false; + if (evaluatedOnce.get()) + query = userTag + Query + modelTag; + else + query = llmPrefix + Query + modelTag; + encode(query); + evaluatedOnce.set(true); + while (getChatProgress()<100) + { + String s = getNextToken(); + stop = inspectWord(s); + response += emitToken; + if (stop) + break; + } + return response; + } + + /** + * Method to find any stop-Words or partial stop-Word present in current token + * @param str current token decoded + * @return boolean for detection of stop word + */ + private boolean inspectWord(String str) + { + boolean stopWordTriggered = false; + String evaluationString = this.cachedToken + str; + // if stopWord is in evaluationString break loop. + for (String word : stopWords) + { + //use position to access inclusion of Stop-words. Preserve the substring before Stop word. + int position = evaluationString.indexOf(word); + if(position!=-1) + { + stopWordTriggered = true; + emitToken = evaluationString.substring(0,position); + cachedToken = ""; + return stopWordTriggered; + } + } + emitToken = evaluationString; + for (String word : stopWords) + { + String partialWord = word; + partialWord = partialWord.substring(0, partialWord.length() - 1); + while (!partialWord.isEmpty()) + { + // if the beginning for stop word coincides with end of emitted token don't emit current token. 
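// For example (illustrative, assuming "User:" is one of the configured stop words):
// if the text seen so far ends in "...Use", partialWord shrinks from "User" to "Use",
// endsWith() matches and emitToken is cleared, so the text is cached instead of
// emitted. On the next call the cached text plus the new token either completes
// "User:" (caught by the full-match loop above) or no longer looks like the start of
// a stop word and is emitted as usual.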
+ if (evaluationString.endsWith(partialWord)) + { + emitToken = ""; + break; + } else + { + partialWord = partialWord.substring(0, partialWord.length() - 1); + } + } + } + this.cachedToken = emitToken.isEmpty() ? evaluationString : ""; + return stopWordTriggered; + } + + /** + * Sets the LLM prefix used for query processing. + * @param llmPrefix initial prompt for llm + */ + public void setLlmPrefix(String llmPrefix) + { + this.llmPrefix = llmPrefix; + } + + /** + * Sets the LLM ModelTag + */ + public void setLlmModelTag(String newTag) + { + this.modelTag = newTag; + } + /** + * Method to free model from memory + */ + public void freeModel() + { + freeLlm(); + this.close(); + evaluatedOnce.set(false); + } + + +} diff --git a/src/java/com/arm/LlmConfig.java b/src/java/com/arm/LlmConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..f70c11eec3656d629791fd351360f99b53e24659 --- /dev/null +++ b/src/java/com/arm/LlmConfig.java @@ -0,0 +1,163 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +package com.arm; + +import java.util.List; + +public class LlmConfig +{ + private String modelTag; + private String userTag; + private String modelPath; + private String llmPrefix; + private List stopWords; + private int numThreads; + // minimal constructor without userTag and numThreads + public LlmConfig(String modelTag, List stopWords, String modelPath, String llmPrefix) + { + this(modelTag, stopWords, modelPath, llmPrefix, "", 4); + } + // minimal constructor without numThreads + public LlmConfig(String modelTag, List stopWords, String modelPath, String llmPrefix, String userTag) + { + // Use 4 threads by default + this(modelTag, stopWords, modelPath, llmPrefix, userTag, 4); + } + // minimal constructor without userTag + public LlmConfig(String modelTag, List stopWords, String modelPath, String llmPrefix,int numThreads) + { + this(modelTag, stopWords, modelPath, llmPrefix, "", numThreads); + } + // main constructor + public LlmConfig(String modelTag, List stopWords, String modelPath, + String llmPrefix, String userTag, int numThreads) + { + this.modelTag = modelTag; + this.stopWords = stopWords; + this.modelPath = modelPath; + this.llmPrefix = llmPrefix; + this.userTag = userTag; + this.numThreads = numThreads; + } + + /** + * Gets the model tag. + * + * @return The model tag. + */ + public String getModelTag() + { + return this.modelTag; + } + /** + * Gets the user tag. + * + * @return The user tag. + */ + public String getUserTag() + { + return this.userTag; + } + /** + * Gets the list of stop words. + * + * @return The list of stop words. + */ + public List getStopWords() + { + return this.stopWords; + } + + /** + * Gets the model path. + * + * @return The model path. + */ + public String getModelPath() + { + return this.modelPath; + } + + /** + * Gets the LLM prefix. + * + * @return The LLM prefix. + */ + public String getLlmPrefix() + { + return this.llmPrefix; + } + + /** + * Gets the number of Threads used + * @return The number of Threads LLM uses. + */ + public int getNumThreads() + { + return this.numThreads; + } + + /** + * Sets the model tag. + * + * @param modelTag The model tag to set. + */ + public void setModelTag(String modelTag) + { + this.modelTag = modelTag; + } + + /** + * Sets the user tag. + * + * @param userTag The user tag to set. + */ + public void setUserTag(String userTag) + { + this.userTag = userTag; + } + + /** + * Sets the list of stop words. 
+ * + * @param stopWords The list of stop words to set. + */ + public void setStopWords(List stopWords) + { + this.stopWords = stopWords; + } + + /** + * Sets the model path. + * + * @param modelPath The model path to set. + */ + public void setModelPath(String modelPath) + { + this.modelPath = modelPath; + } + + /** + * Sets the LLM prefix. + * + * @param llmPrefix The LLM prefix to set. + */ + public void setLlmPrefix(String llmPrefix) + { + this.llmPrefix = llmPrefix; + } + + /** + * Sets the number of Threads. + * @param numThreads count of threads to use for LLM. + */ + public void setNumThreads(int numThreads) + { + this.numThreads = numThreads; + } +} + diff --git a/src/java/com/arm/llm/Llama.java b/src/java/com/arm/llm/Llama.java deleted file mode 100644 index cce26e50a6d3d30dbe4f784943f4886e5679c2fd..0000000000000000000000000000000000000000 --- a/src/java/com/arm/llm/Llama.java +++ /dev/null @@ -1,380 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// - -package com.arm.llm; - -import java.util.List; -import java.util.concurrent.Flow; -import java.util.concurrent.SubmissionPublisher; -import java.util.concurrent.atomic.AtomicBoolean; - -public class Llama extends SubmissionPublisher -{ - static - { - try - { - System.loadLibrary("arm-llm-jni"); - } catch (UnsatisfiedLinkError e) - { - System.err.println("Llama: Failed to load library: arm-llm-jni"); - e.printStackTrace(); - } - } - - private long llmContext = 0; - private long modelPointer = 0; - private String modelTag = ""; - private static final int embeddings = 0; - private static final int tokens = 150; - private static final int sequenceMax = 1; - private static final int nLen = 1024; - private int nCur = 0; - private long batch = 0; - private List stopWords = null; - private String cachedToken = ""; - private String emitToken = ""; - private String llmPrefix = ""; - private int numThreads; - - //ToDo Create a session manager to manage a conversion instead of "evaluatedOnce" - private AtomicBoolean evaluatedOnce = new AtomicBoolean(false); - - // Native method declarations - - /** - * Method for loading LLM model - * - * @param pathToModel file path for loading model - * @return pointer to loaded model - */ - public native long loadModel(String pathToModel); - - /** - * Method for freeing LLM model - * - * @param modelPtr to free model - * @param contextPtr for freeing up LLM context - */ - public native void freeModel(long modelPtr,long contextPtr); - - /** - * Method for getting encode timing - * - * @param contextPtr LLM context pointer to the loaded context - * @return timings in tokens/s for encoding prompt - */ - public native float getEncodeTimings(long contextPtr); - - /** - * Method for getting decode timing - * - * @param contextPtr LLM context pointer to the loaded context - * @return timings in tokens/s for decoding prompt - */ - public native float getDecodeTimings(long contextPtr); - - /** - * Method for getting a new llama context - * - * @param modelPtr to loaded model - * @param numThreads number of threads to use - * @return pointer to LLM context loaded - */ - public native long newContext(long modelPtr, int numThreads); - - /** - * Method for clearing previous chat history from llama - * - * @param context LLM context pointer to the loaded context - */ - private native void kvCacheClear(long context); - - /** - * Method for clearing previous chat history until specified point from llama - * - * @param context LLM 
context pointer to the loaded context - * @param startPos starting index from which to delete context memory (inclusive) - * @param lastPos the index upto which to clear (exclusive) - */ - private native void kvCacheSeqRm(long context, int startPos, int lastPos); - - /** - * Method for resetting timing information - * - * @param contextPtr LLM context pointer to the loaded context - */ - public native void resetTimings(long contextPtr); - - /** - * Method for getting a new sampler - * - * @param contextPtr LLM context pointer to the loaded context - * @return pointer to sampler - */ - public native long newSampler(long contextPtr); - - /** - * Method to get llmPrefix length in terms of tokens - * - * @param modelPtr pointer to LLM model - * @param textLength length of initial prompt - * @param text LLM Prefix to be encoded - * @param addSpecial bool for optional special character at end of prompt - * @return length of original prompt - */ - public native int getInitialPromptLength(long modelPtr, int textLength, String text, boolean addSpecial); - - /** - * Method to get initializes the llama backend - */ - public native void backendInit(); - - /** - * Method to create new batch - * - * @param numTokens number of allowed tokens in the batch - * @param embeddings number of allowed embeddings in the batch - * @param numSequenceMax number of sequences allowed - * @return newly created batch object - */ - public native long newBatch(int numTokens, int embeddings, int numSequenceMax); - - - /** - * Method to Encode the given text and return number of tokens in prompt - * - * @param text the prompt to be encoded - * @param context pointer to llama_context instance - * @param batch pointer to llama batch of type llama_batch - * @param startPos starting index of positional embeddings from which to populate current question - * @return number of tokens if successful otherwise error code - */ - public native int completionInit( - String text, - long context, - long batch, - int startPos - ); - - /** - * Method to decode answers one by one, once prefill stage is completed - * - * @param context pointer to llama_context instance - * @param batch pointer to llama batch of type llama_batch - * @param nLen max length of context memory to be filled - * @param currentPos starting index of positional embeddings to populate current decoded token - * @return generated token as a string - */ - public native String completionLoop( - long context, - long batch, - int nLen, - int currentPos - ); - - /** - *Method to separate Initialization from constructor - * - *@param llamaConfig type configuration file to load model - */ - public void llmInit(LlamaConfig llamaConfig) - { - this.modelPointer = loadModel(llamaConfig.getModelPath()); - this.numThreads = llamaConfig.getNumThreads(); - this.llmContext = newContext(modelPointer, numThreads); - this.batch = newBatch(tokens, embeddings, sequenceMax); - this.stopWords = llamaConfig.getStopWords(); - this.modelTag = llamaConfig.getModelTag(); - this.llmPrefix = llamaConfig.getLlmPrefix(); - } - - /** - *Method to assing a new subscriber to this publisher - * - *@param subscriber subscriber that will receive published tokens - */ - public void setSubscriber(Flow.Subscriber subscriber) - { - System.out.println("subscribed set from llama"); - this.subscribe(subscriber); - } - - /** - * Method to get response of a query asynchronously - * - * @param Query the prompt asked - */ - public void sendAsync(String Query) - { - - String query = ""; - AtomicBoolean stop = new 
AtomicBoolean(false); - if (evaluatedOnce.get()) - query = Query + modelTag; - else - query = llmPrefix + Query + modelTag; - nCur += completionInit(query, this.llmContext, this.batch, nCur); - evaluatedOnce.set(true); - while (nCur <= nLen) - { - String s = completionLoop(this.llmContext, this.batch, nCur, nLen); - stop.set(inspectWord(s)); - if (stop.get()) - { - emitToken = ""; - ++nCur; - this.submit(emitToken); - - break; - } - ++nCur; - this.submit(emitToken); - } - } - - /** - * Method to find any stop-Words or partial stop-Word present in current token - * - * @param str current token decoded - * @return boolean for detection of stop word - */ - private boolean inspectWord(String str) - { - boolean stopWordTriggered = false; - String evaluationString = this.cachedToken + str; - // if stopWord is in evaluationString break loop - for (String word : stopWords) - { - if (evaluationString.contains(word)) - { - stopWordTriggered = true; - emitToken = ""; - cachedToken = ""; - return stopWordTriggered; - } - } - emitToken = evaluationString; - for (String word : stopWords) - { - String partialWord = word; - partialWord = partialWord.substring(0, partialWord.length() - 1); - while (!partialWord.isEmpty()) - { - if (evaluationString.endsWith(partialWord)) // if the beginning for stop word coincides with end of emitted token dont emit current token - { - emitToken = ""; - break; - } else - { - partialWord = partialWord.substring(0, partialWord.length() - 1); - } - } - } - this.cachedToken = emitToken.isEmpty() ? evaluationString : ""; - return stopWordTriggered; - } - - /** - * Method to reset conversation history - */ - public void resetContext() - { - - int nPrefix = getInitialPromptLength(this.modelPointer, this.llmPrefix.length(), this.llmPrefix, true); - if (nPrefix < 0) - { - nPrefix = 0; - } - kvCacheSeqRm(this.llmContext, nPrefix, -1); - resetTimings(this.llmContext); - nCur = nPrefix; - } - - /** - * Method to get response of a query synchronously - * - * @param Query the prompt asked - * @return response of LLM - */ - public String send(String Query) - { - String response = ""; - String query = ""; - boolean stop = false; - if (evaluatedOnce.get()) - query = Query + modelTag; - else - query = llmPrefix + Query + modelTag; - nCur += completionInit(query, this.llmContext, this.batch, nCur); - evaluatedOnce.set(true); - while (nCur <= nLen) - { - - String s = completionLoop(this.llmContext, this.batch, nCur, nLen); - stop = inspectWord(s); - if (!stop) - { - response += emitToken; - } else - { - ++nCur; - break; - } - ++nCur; - } - - return response; - } - - /** - * Method to get current encode timings - * - * @return encode timings in tokens/s - */ - public float getEncodeRate() - { - return getEncodeTimings(this.llmContext); - } - - /** - * Method to get current decode timings - * - * @return decode timings in tokens/s - */ - public float getDecodeRate() - { - return getDecodeTimings(this.llmContext); - } - - /** - * Sets the LLM prefix used for query processing - */ - public void setLlmPrefix(String llmPrefix) - { - this.llmPrefix = llmPrefix; - } - - /** - * Sets the LLM ModelTag - */ - public void setLlmModelTag(String newTag) - { - this.modelTag = newTag; - } - /** - * Method to free model from memory - */ - public void freeModel() - { - evaluatedOnce.set(false); - freeModel(this.modelPointer,this.llmContext); - this.close(); // Publisher is closed - } - - -} - diff --git a/src/java/com/arm/llm/LlamaConfig.java b/src/java/com/arm/llm/LlamaConfig.java deleted file mode 100644 
index 1cb7ca5a561608f611f5f5b6d01516421d47ff66..0000000000000000000000000000000000000000 --- a/src/java/com/arm/llm/LlamaConfig.java +++ /dev/null @@ -1,129 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// - -package com.arm.llm; - -import java.util.List; - -public class LlamaConfig -{ - private static final String LLAMA_MODEL_NAME = "model.gguf"; - private String modelTag; - private String modelPath; - private String llmPrefix; - private List stopWords; - private Integer numThreads; - - public LlamaConfig(String modelTag, List stopWords, String modelPath, String llmPrefix, Integer numThreads) - { - this.modelTag = modelTag; - this.stopWords = stopWords; - this.modelPath = modelPath; - this.llmPrefix = llmPrefix; - this.numThreads = numThreads; - } - - /** - * Gets the number of threads - * - * @return The number of threads - */ - public Integer getNumThreads() - { - return this.numThreads; - } - - /** - * Gets the model tag - * - * @return The model tag - */ - public String getModelTag() - { - return this.modelTag; - } - - /** - * Gets the list of stop words - * - * @return The list of stop words - */ - public List getStopWords() - { - return this.stopWords; - } - - /** - * Gets the model path - * - * @return The model path - */ - public String getModelPath() - { - return this.modelPath; - } - - /** - * Gets the LLM prefix - * - * @return The LLM prefix - */ - public String getLlmPrefix() - { - return this.llmPrefix; - } - - /** - * Sets the number of threads - * - * @param numThreads The number of threads to set - */ - public void setNumThreads(Integer numThreads) - { - this.numThreads = numThreads; - } - - /** - * Sets the model tag - * - * @param modelTag The model tag to set - */ - public void setModelTag(String modelTag) - { - this.modelTag = modelTag; - } - - /** - * Sets the list of stop words - * - * @param stopWords The list of stop words to set - */ - public void setStopWords(List stopWords) - { - this.stopWords = stopWords; - } - - /** - * Sets the model path - * - * @param modelPath The model path to set - */ - public void setModelPath(String modelPath) - { - this.modelPath = modelPath + "/" + LLAMA_MODEL_NAME; - } - - /** - * Sets the LLM prefix - * - * @param llmPrefix The LLM prefix to set - */ - public void setLlmPrefix(String llmPrefix) - { - this.llmPrefix = llmPrefix; - } -} - diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 2dac78d83cd2a35fa52b2d11bd80e1debb5878b6..54886b2ba318740457be040411f5020776ed55e9 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -21,7 +21,14 @@ if (NOT TARGET arm-llm-cpp) endif() message(STATUS "Adding C++ tests") -add_executable(llm-cpp-tests cpp/LLamaTest.cpp) +add_executable(llm-cpp-tests + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/LlmUtils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/LlmTest.cpp) + +if (${LLM_DEP_NAME} STREQUAL "llama.cpp") + set(CONFIG_FILE_NAME "llamaConfig.txt" CACHE STRING "Path to the Llama config file") +endif () + # We pass in the compile definition for location of test # models directory to know where to read the model file @@ -49,8 +56,8 @@ target_include_directories(llm-cpp-tests PUBLIC ${CATCH_DIR}) enable_testing() add_test(NAME llm-cpp-ctest COMMAND llm-cpp-tests) -# Ensure the configuration file is available to both C++ and JNI tests -set(CONFIG_FILE_NAME "llamaConfig.txt") +set(USER_CONFIG_FILE_NAME "LLMUserConfig.txt") + set(CONFIG_FILE_SOURCE "${CMAKE_SOURCE_DIR}/model_configuration_files/${CONFIG_FILE_NAME}") 
set(CONFIG_FILE_DEST "${CMAKE_BINARY_DIR}/${CONFIG_FILE_NAME}") @@ -63,6 +70,18 @@ configure_file(${CONFIG_FILE_SOURCE} ${CONFIG_FILE_DEST} COPYONLY) target_compile_definitions(llm-cpp-tests PUBLIC CONFIG_FILE_PATH="${CONFIG_FILE_DEST}") +set(USER_CONFIG_FILE_SOURCE "${CMAKE_SOURCE_DIR}/model_configuration_files/${USER_CONFIG_FILE_NAME}") +set(USER_CONFIG_FILE_DEST "${CMAKE_BINARY_DIR}/${USER_CONFIG_FILE_NAME}") + +if(NOT EXISTS ${USER_CONFIG_FILE_SOURCE}) + message(FATAL_ERROR "Configuration file not found: ${USER_CONFIG_FILE_SOURCE}") +endif() + +configure_file(${USER_CONFIG_FILE_SOURCE} ${USER_CONFIG_FILE_DEST} COPYONLY) + +target_compile_definitions(llm-cpp-tests + PUBLIC USER_CONFIG_FILE_PATH="${USER_CONFIG_FILE_DEST}") + # If JNI libs are being built, add tests for these here. if(TARGET arm-llm-jni) message(STATUS "Adding JNI tests") @@ -86,7 +105,7 @@ if(TARGET arm-llm-jni) add_jar( llm-jni-tests ${LLM_JAVA_INTERFACE_SOURCES} - ${CMAKE_CURRENT_SOURCE_DIR}/java/com/arm/LlamaTestJNI.java + java/com/arm/LlmTestJNI.java INCLUDE_JARS ${_jarDependencies}) @@ -102,13 +121,14 @@ if(TARGET arm-llm-jni) string(REPLACE ";" ":" _jarDependenciesTests "${_jarDependencies}") add_test( - NAME llama-jni-ctest + NAME llm-jni-ctest COMMAND ${Java_JAVA_EXECUTABLE} -Djava.library.path=${_jniLibPath} -Dmodel_dir=${TEST_MODELS_DIR} -Dconfig_file=${CONFIG_FILE_DEST} + -Duser_config_file=${USER_CONFIG_FILE_DEST} -cp ${_jarDependenciesTests}:${_jarFile} org.junit.runner.JUnitCore - com.arm.LlamaTestJNI) + com.arm.LlmTestJNI) endif() diff --git a/test/cpp/LLamaTest.cpp b/test/cpp/LLamaTest.cpp deleted file mode 100644 index 32ecc43afa2b9b81e083ab2dc0a6be09b19a5449..0000000000000000000000000000000000000000 --- a/test/cpp/LLamaTest.cpp +++ /dev/null @@ -1,111 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates -// SPDX-License-Identifier: Apache-2.0 -// -#define CATCH_CONFIG_MAIN - -#include -#include "LLM.hpp" -#include "LlamaImpl.hpp" -#include -#include -#include -#include - -// Function to parse the configuration file -std::unordered_map LoadConfig(const std::string &configFilePath) { - std::unordered_map config; - std::ifstream file(configFilePath); - - if (!file.is_open()) { - throw std::runtime_error("Failed to open config file: " + configFilePath); - } - - std::string line; - while (std::getline(file, line)) { - if (line.empty() || line[0] == '#') continue; // Skip empty lines and comments - - size_t delimiterPos = line.find('='); - if (delimiterPos != std::string::npos) { - std::string key = line.substr(0, delimiterPos); - std::string value = line.substr(delimiterPos + 1); - config[key] = value; - } - } - - return config; -} - -//ToDo AFAIC Several variables could also be read-in at compile time e.g. embeddings, nLen etc. 
-// Variables that need to be set after config file parsing -static constexpr int embeddings = 150; -static constexpr int tokens = 0; -static constexpr int sequenceMax = 1; -std::string testModelsDir = TEST_MODELS_DIR; -std::string modelPath = - testModelsDir + "/model.gguf"; -std::list STOP_WORDS; -std::string llmPrefix; -std::string modelTag; -int nLen = 1024; - -// Function to load configuration before tests -void InitializeConfig() { - std::string configFilePath = CONFIG_FILE_PATH; - std::unordered_map config = LoadConfig(configFilePath); - - llmPrefix = config["llmPrefixDefault"]; - modelTag = config["modelTagDefault"]; - - // Parse stopWordsDefault into a list - std::istringstream stopWordsStream(config["stopWordsDefault"]); - std::string word; - while (std::getline(stopWordsStream, word, ',')) { - STOP_WORDS.push_back(word); - } -} - -// Call InitializeConfig before tests run -struct ConfigInitializer { - ConfigInitializer() { - InitializeConfig(); - } -} configInitializer; // Global instance ensures initialization before tests - -/** - * Simple query->response test - * ToDo Replace with more sophisticated context tests if/when reset context is available in Cpp layer - */ -TEST_CASE("Test Query Response") { - std::string response; - int nCur = 0; - const std::string question = "What is the capital of France?" + modelTag; - const std::string prefixedQuestion = llmPrefix + question; - - LLM llm; - auto *model = llm.LoadModel(modelPath.c_str()); - llm.BackendInit(); - - auto *context = llm.NewContext(model, 2); - auto batch = llm.NewBatch(embeddings, tokens, sequenceMax); - nCur = llm.CompletionInit(prefixedQuestion, context, &batch, 0); - - while (nCur <= nLen) { - std::string s = llm.CompletionLoop(context, &batch, nCur, nLen); - if ((std::find(STOP_WORDS.begin(), STOP_WORDS.end(), s) != STOP_WORDS.end())) { - break; - } - response += s; - } - CHECK(response.find("Paris") != std::string::npos); -} - -/** - * Test Load Empty Model returns nullptr - */ -TEST_CASE("Test Load Empty Model") { - std::string emptyModelPath; - LLM llm; - auto *model = llm.LoadModel(emptyModelPath.c_str()); - CHECK(model == nullptr); -} diff --git a/test/cpp/LlmTest.cpp b/test/cpp/LlmTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d0f520bc66ac95ddf272ce8417ac71fcd5a083aa --- /dev/null +++ b/test/cpp/LlmTest.cpp @@ -0,0 +1,90 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#define CATCH_CONFIG_MAIN + +#include "catch.hpp" + +#include "LlmImpl.hpp" +#include "LlmUtils.hpp" + +#include +#include + +// Function to create the configuration file from CONFIG_FILE_PATH +void SetupTestConfig(std::stringstream& stopWordsStream, + LlmConfig* configTest, + std::list& STOP_WORDS) +{ + std::string configFilePath = CONFIG_FILE_PATH; + std::string userConfigFilePath = USER_CONFIG_FILE_PATH; + auto config = Llm::Test::Utils::LoadConfig(configFilePath); + stopWordsStream.str(""); + stopWordsStream.clear(); + STOP_WORDS.clear(); + stopWordsStream << config["stopWords"]; + std::string word; + while (std::getline(stopWordsStream, word, ',')) { + STOP_WORDS.push_back(word); + } + std::string testModelsDir = TEST_MODELS_DIR; + std::string modelPath = testModelsDir + "/" + config["llmModelName"]; + config["modelPath"] = modelPath; + auto userConfig = Llm::Test::Utils::LoadUserConfig(userConfigFilePath); + *configTest = Llm::Test::Utils::GetConfig(config, userConfig); + configTest->SetModelPath(modelPath); +} +/** + 
* Simple query->response test + * ToDo Replace with more sophisticated context tests if/when reset context is available in Cpp + * layer + */ +TEST_CASE("Test Llm-Wrapper class") +{ + LlmConfig configTest{}; + std::stringstream stopWordsStream; + std::list STOP_WORDS; + SetupTestConfig(stopWordsStream, &configTest, STOP_WORDS); + + std::string response; + std::string question = "What is the capital of France?" + configTest.GetModelTag(); + std::string prefixedQuestion = configTest.GetLlmPrefix() + question; + LLM llm; + + SECTION("Simple Query Response") + { + llm.LlmInit(configTest); + llm.Encode(prefixedQuestion); + bool stop = false; + + while (llm.GetChatProgress() < 100) { + std::string s = llm.NextToken(); + for (auto& stopWord : STOP_WORDS) { + if (s.find(stopWord) != std::string::npos) { + stop = true; + break; + } + } + if (stop) + { + break; + } + response += s; + } + CHECK(response.find("Paris") != std::string::npos); + } + + /** + * Test Load Empty Model returns nullptr + */ + SECTION("Test Load Empty Model") + { + std::string emptyString; + configTest.SetModelPath(emptyString); + REQUIRE_THROWS(llm.LlmInit(configTest)); + } + + llm.FreeLlm(); +} diff --git a/test/cpp/LlmUtils.cpp b/test/cpp/LlmUtils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3d90bae37d704f071464539be6de310a45e8eb0e --- /dev/null +++ b/test/cpp/LlmUtils.cpp @@ -0,0 +1,99 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "LlmUtils.hpp" +#include +#include + +namespace Llm::Test::Utils { + +std::unordered_map LoadConfig(const std::string& configFilePath) +{ + std::unordered_map config; + std::ifstream file(configFilePath); + + if (!file.is_open()) { + throw std::runtime_error("Failed to open config file: " + configFilePath); + } + + std::string line; + while (std::getline(file, line)) { + if (line.empty() || line[0] == '#') + continue; // Skip empty lines and comments + + size_t delimiterPos = line.find('='); + if (delimiterPos != std::string::npos) { + std::string key = line.substr(0, delimiterPos); + std::string value = line.substr(delimiterPos + 1); + config[key] = value; + } + } + + return config; +} + +std::unordered_map LoadUserConfig(const std::string& userConfigFilePath) +{ + std::unordered_map config; + std::ifstream file(userConfigFilePath); + + if (!file.is_open()) { + throw std::runtime_error("Failed to open config file: " + userConfigFilePath); + } + + std::string line; + while (std::getline(file, line)) { + if (line.empty() || line[0] == '#') + continue; // Skip empty lines and comments + + size_t delimiterPos = line.find('='); + if (delimiterPos != std::string::npos) { + std::string key = line.substr(0, delimiterPos); + std::string value = line.substr(delimiterPos + 1); + // sanitize the numerical values + try { + size_t numSize; + int numericalValue = std::stoi(value, &numSize); + // ensure only numbers are present + if (numSize != value.length()) { + throw std::invalid_argument("Extra characters after number"); + } + config[key] = numericalValue; + } catch (const std::invalid_argument& e) { + std::cerr << "Invalid input: " << value << " for " << key << "\n"; + } catch (const std::out_of_range& e) { + std::cerr << "Out of range input: " << value << " for " << key << "\n"; + } + } + } + + return config; +} + +LlmConfig GetConfig(std::unordered_map config, + std::unordered_map userConfig) +{ + if (config.find("modelPath") == config.end()) + throw std::runtime_error("Missing 
required parameter: modelPath");
+    if (config.find("modelTag") == config.end())
+        throw std::runtime_error("Missing required parameter: modelTag");
+    if (config.find("llmPrefix") == config.end())
+        throw std::runtime_error("Missing required parameter: llmPrefix");
+
+    if (userConfig.find("batchSize") == userConfig.end())
+        throw std::runtime_error("Missing required parameter: batchSize");
+    if (userConfig.find("numThreads") == userConfig.end())
+        throw std::runtime_error("Missing required parameter: numThreads");
+    if (config.find("stopWords") == config.end())
+        throw std::runtime_error("Missing required parameter: stopWords");
+
+    return LlmConfig(config.at("modelTag"),
+                     config.at("modelPath"),
+                     config.at("llmPrefix"),
+                     userConfig.at("numThreads"),
+                     userConfig.at("batchSize"));
+}
+
+} /* namespace Llm::Test::Utils */
diff --git a/test/cpp/LlmUtils.hpp b/test/cpp/LlmUtils.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..2b798143863c2769a08e354430edc345649d82b2
--- /dev/null
+++ b/test/cpp/LlmUtils.hpp
@@ -0,0 +1,36 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#ifndef LLM_TEST_UTILS_HPP
+#define LLM_TEST_UTILS_HPP
+
+#include "LlmConfig.hpp"
+
+#include <string>
+#include <unordered_map>
+
+namespace Llm::Test::Utils {
+
+std::unordered_map<std::string, std::string> LoadConfig(const std::string& configFilePath);
+
+/**
+ * Method to load LLM-Config params from a file in "param=value(int)" format
+ * @param userConfigFilePath path to the user configuration file
+ * @return a dictionary of user-defined params.
+ */
+std::unordered_map<std::string, int> LoadUserConfig(const std::string& userConfigFilePath);
+
+/**
+ * Method to create an LlmConfig instance from the model-configuration and user-configuration dictionaries
+ * @param config Model configuration dictionary which contains model-related details like llmPrefix,
+ * modelTag etc.
+ * @return An LlmConfig object which can be used to construct an LLM instance.
+ */ +LlmConfig GetConfig(std::unordered_map config, + std::unordered_map userConfig); + +} /* namespace Llm::Test::Utils */ + +#endif /* LLM_TEST_UTILS_HPP */ diff --git a/test/java/com/arm/LlamaTestJNI.java b/test/java/com/arm/LlamaTestJNI.java deleted file mode 100644 index c259d14c06090550b9890d6c1b9d51a30876475e..0000000000000000000000000000000000000000 --- a/test/java/com/arm/LlamaTestJNI.java +++ /dev/null @@ -1,432 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// - -package com.arm; - -import static org.junit.Assert.*; -import static org.junit.Assume.assumeTrue; - -import org.junit.Test; -import org.junit.BeforeClass; - -import com.arm.llm.Llama; -import com.arm.llm.LlamaConfig; - -import java.io.*; -import java.util.*; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Flow; -import java.util.concurrent.TimeUnit; - - -public class LlamaTestJNI { - private static final String modelDir = System.getProperty("model_dir"); - private static final String configFilePath = System.getProperty("config_file"); - private static final Map variables = new HashMap<>(); - private static final String LLAMA_MODEL_NAME = "model.gguf"; - private static final int numThreads = 4; - - // Timeout for subscriber latch await in seconds - private static final long LATCH_TIMEOUT_SECONDS = 5; - - private static String modelTag = ""; - private static String modelPath = ""; - private static String llmPrefix = ""; - private static List stopWords = new ArrayList<>(); - - /** - * Instead of matching the actual response to expected response, - * check whether the response contains the salient parts of expected response. - * Pass true to check match and false to assert absence of salient parts for negative tests. - */ - private static void checkLlamaMatch(String response, String expectedResponse, boolean checkMatch) { - boolean matches = response.contains(expectedResponse); - if (!matches) { - System.out.println("Response mismatch: response={" + response + "} expected={" + expectedResponse + "}"); - } - if (checkMatch) { - assertTrue(matches); - } else { - assertFalse(matches); - } - } - - /** - * Loads variables from the specified configuration file. - * - * @param filePath Path to the configuration file. - * @throws IOException If an I/O error occurs. 
- */ - private static void loadVariables(String filePath) throws IOException { - try (BufferedReader br = new BufferedReader(new FileReader(filePath))) { - String line; - while ((line = br.readLine()) != null) { - if (!line.contains("=")) continue; - String[] parts = line.split("=", 2); - if (parts[0].trim().equals("stopWordsDefault")) { - stopWords.clear(); // Ensure no duplicates on reloading - stopWords.addAll(Arrays.asList(parts[1].split(","))); - } else { - variables.put(parts[0].trim(), parts[1].trim()); - } - } - } catch (FileNotFoundException e) { - throw new IOException("Configuration file not found: " + filePath); - } catch (Exception e) { - throw new IOException("Error reading configuration file: " + e.getMessage()); - } - } - - @BeforeClass - public static void classSetup() throws IOException { - if (modelDir == null) throw new RuntimeException("System property 'model_dir' is not set!"); - if (configFilePath == null) - throw new RuntimeException("System property 'config_file' is not set!"); - - loadVariables(configFilePath); - modelTag = variables.get("modelTagDefault"); - llmPrefix = variables.get("llmPrefixDefault"); - modelPath = modelDir + "/" + LLAMA_MODEL_NAME; - } - - /** - * A test implementation of Flow.Subscriber that collects tokens from asynchronous operations. - * It accumulates received tokens in a list and uses a CountDownLatch to signal when the end-of-stream - * token ("") is received, allowing tests to wait for completion. - */ - static class TestSubscriber implements Flow.Subscriber { - - private Flow.Subscription subscription; - private final List receivedTokens = new ArrayList<>(); - // Latch to signal when has been received. - private final CountDownLatch latch = new CountDownLatch(1); - @Override - public void onSubscribe(Flow.Subscription subscription) { - this.subscription = subscription; - // Request an unlimited number of tokens. - subscription.request(Long.MAX_VALUE); - } - @Override - public void onNext(String token) { - receivedTokens.add(token); - // If the token indicates end-of-stream, count down the latch. - if ("".equals(token)) { - latch.countDown(); - } - } - @Override - public void onError(Throwable throwable) { - // In case of error, count down the latch so the test can proceed. - latch.countDown(); - } - @Override - public void onComplete() { - latch.countDown(); - } - public List getReceivedTokens() { - return receivedTokens; - } - public boolean await(long timeout, TimeUnit unit) throws InterruptedException { - return latch.await(timeout, unit); - } - } - - @Test - public void testAsyncPublishing() throws Exception { - // Create and initialize the Llama instance with test config using global variables. - Llama llama = new Llama(); - LlamaConfig llamaConfig = new LlamaConfig(modelTag, stopWords, modelPath, llmPrefix, numThreads); - llama.llmInit(llamaConfig); - // Set up our test subscriber. - TestSubscriber subscriber = new TestSubscriber(); - llama.setSubscriber(subscriber); - llama.sendAsync("what is 2 + 2"); - // Wait up to LATCH_TIMEOUT_SECONDS seconds for the subscriber to receive the token. - boolean completed = subscriber.await(LATCH_TIMEOUT_SECONDS, TimeUnit.SECONDS); - assertTrue("Subscriber received token within timeout", completed); - // Retrieve and check the received tokens. 
- List tokens = subscriber.getReceivedTokens(); - assertEquals("Last token should be ", "", tokens.get(tokens.size() - 1)); - StringBuilder tokenString = new StringBuilder(); - for (String token : tokens) { - tokenString.append(token); - } - // Check that the tokens contain the number "4". - assertTrue("Tokens should contain the number 4", tokenString.toString().contains("4")); - // Clean up the model resources. - llama.freeModel(); - } - - @Test - public void testAsyncInferenceWithoutContextReset() throws Exception { - // Create and initialize the Llama instance with global config. - LlamaConfig llamaConfig = new LlamaConfig(modelTag, stopWords, modelPath, llmPrefix, numThreads); - Llama llama = new Llama(); - llama.llmInit(llamaConfig); - TestSubscriber subscriber1 = new TestSubscriber(); - llama.setSubscriber(subscriber1); - String question1 = "What is the capital of Morocco?"; - llama.sendAsync(question1); - // Wait up to LATCH_TIMEOUT_SECONDS seconds for the subscriber to receive the token. - boolean completed1 = subscriber1.await(LATCH_TIMEOUT_SECONDS, TimeUnit.SECONDS); - assertTrue("Subscriber received token within timeout for question1", completed1); - // Retrieve and check the received tokens. - List tokens1 = subscriber1.getReceivedTokens(); - StringBuilder tokenString1 = new StringBuilder(); - for (String token : tokens1) { - tokenString1.append(token); - } - String response1 = tokenString1.toString(); - checkLlamaMatch(response1, "Rabat", true); - TestSubscriber subscriber2 = new TestSubscriber(); - llama.setSubscriber(subscriber2); - String question2 = "What languages do they speak there?"; - llama.sendAsync(question2); - boolean completed2 = subscriber2.await(LATCH_TIMEOUT_SECONDS, TimeUnit.SECONDS); - assertTrue("Subscriber received token within timeout for question2", completed2); - List tokens2 = subscriber2.getReceivedTokens(); - StringBuilder tokenString2 = new StringBuilder(); - for (String token : tokens2) { - tokenString2.append(token); - } - String response2 = tokenString2.toString(); - checkLlamaMatch(response2, "Arabic", true); - // Clean up the model resources. - llama.freeModel(); - } - - @Test - public void testAsyncInferenceRecoversAfterContextReset() throws Exception { - LlamaConfig llamaConfig = new LlamaConfig(modelTag, stopWords, modelPath, llmPrefix, numThreads); - Llama llama = new Llama(); - llama.llmInit(llamaConfig); - String question1 = "What is the capital of Morocco?"; - TestSubscriber subscriber1 = new TestSubscriber(); - llama.setSubscriber(subscriber1); - llama.sendAsync(question1); - boolean completed1 = subscriber1.await(LATCH_TIMEOUT_SECONDS, TimeUnit.SECONDS); - assertTrue("Subscriber should receive token for question1", completed1); - List tokens1 = subscriber1.getReceivedTokens(); - StringBuilder tokenString1 = new StringBuilder(); - for (String token : tokens1) { - tokenString1.append(token); - } - String response1 = tokenString1.toString(); - checkLlamaMatch(response1, "Rabat", true); - // Reset context before the next question. 
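// Note: resetContext() re-tokenizes llmPrefix, removes every KV-cache entry after
// that prefix (kvCacheSeqRm(nPrefix, -1)) and rewinds nCur, so the system prompt is
// kept while the Morocco exchange above is forgotten; the follow-up question is
// therefore expected NOT to resolve "there" to Morocco, hence the
// checkLlamaMatch(response2, "Arabic", false) assertion below.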
- llama.resetContext(); - String question2 = "What languages do they speak there?"; - TestSubscriber subscriber2 = new TestSubscriber(); - llama.setSubscriber(subscriber2); - llama.sendAsync(question2); - boolean completed2 = subscriber2.await(LATCH_TIMEOUT_SECONDS, TimeUnit.SECONDS); - assertTrue("Subscriber should receive token for question2", completed2); - List tokens2 = subscriber2.getReceivedTokens(); - StringBuilder tokenString2 = new StringBuilder(); - for (String token : tokens2) { - tokenString2.append(token); - } - String response2 = tokenString2.toString(); - checkLlamaMatch(response2, "Arabic", false); - llama.resetContext(); - TestSubscriber subscriber3 = new TestSubscriber(); - llama.setSubscriber(subscriber3); - llama.sendAsync(question1); - boolean completed3 = subscriber3.await(LATCH_TIMEOUT_SECONDS, TimeUnit.SECONDS); - assertTrue("Subscriber should receive token for question3", completed3); - List tokens3 = subscriber3.getReceivedTokens(); - StringBuilder tokenString3 = new StringBuilder(); - for (String token : tokens3) { - tokenString3.append(token); - } - String response3 = tokenString3.toString(); - checkLlamaMatch(response3, "Rabat", true); - // Fourth Question: Ask second question again. - TestSubscriber subscriber4 = new TestSubscriber(); - llama.setSubscriber(subscriber4); - llama.sendAsync(question2); - boolean completed4 = subscriber4.await(LATCH_TIMEOUT_SECONDS, TimeUnit.SECONDS); - assertTrue("Subscriber should receive token for question4", completed4); - List tokens4 = subscriber4.getReceivedTokens(); - StringBuilder tokenString4 = new StringBuilder(); - for (String token : tokens4) { - tokenString4.append(token); - } - String response4 = tokenString4.toString(); - checkLlamaMatch(response4, "Arabic", true); - checkLlamaMatch(response4, "French", true); - // Free model resources. - llama.freeModel(); - } - - - @Test - public void testConfigLoading() { - LlamaConfig llamaConfig = new LlamaConfig(modelTag, stopWords, modelPath, llmPrefix, numThreads); - assertTrue("Model tag is not empty", !llamaConfig.getModelTag().isEmpty()); - assertTrue("LLM prefix is not empty", !llamaConfig.getLlmPrefix().isEmpty()); - assertTrue("Stop words list is not empty", !llamaConfig.getStopWords().isEmpty()); - } - - @Test - public void testLlmPrefixSetting() { - LlamaConfig llamaConfig = new LlamaConfig(modelTag, stopWords, modelPath, llmPrefix, numThreads); - Llama llama = new Llama(); - llama.llmInit(llamaConfig); - - String newModelTag = ("Ferdia"); - String newPrefix = "Transcript of a dialog, where the User interacts with an AI Assistant named " + newModelTag + - ". " + newModelTag + - " is helpful, polite, honest, good at writing and answers honestly with a maximum of two sentences. 
User:"; - - llama.setLlmModelTag(newModelTag); - llama.setLlmPrefix(newPrefix); - - String question = "What is your name?"; - String response = llama.send(question); - checkLlamaMatch(response, "Ferdia", true); - llama.freeModel(); - } - - @Test - public void testInferenceWithContextReset() { - LlamaConfig llamaConfig = new LlamaConfig(modelTag, stopWords, modelPath, llmPrefix, numThreads); - Llama llama = new Llama(); - llama.llmInit(llamaConfig); - - String question1 = "What is the capital of Morocco?"; - String response1 = llama.send(question1); - checkLlamaMatch(response1, "Rabat", true); - - // Resetting context should cause model to forget what country is being referred to - llama.resetContext(); - - String question2 = "What languages do they speak there?"; - String response2 = llama.send(question2); - checkLlamaMatch(response2, "Arabic", false); - - llama.freeModel(); - } - - @Test - public void testInferenceWithoutContextReset() { - LlamaConfig llamaConfig = new LlamaConfig(modelTag, stopWords, modelPath, llmPrefix, numThreads); - Llama llama = new Llama(); - llama.llmInit(llamaConfig); - - String question1 = "What is the capital of Morocco?"; - String response1 = llama.send(question1); - checkLlamaMatch(response1, "Rabat", true); - - String question2 = "What languages do they speak there?"; - String response2 = llama.send(question2); - checkLlamaMatch(response2, "Arabic", true); - - llama.freeModel(); - } - - @Test - public void testInferenceHandlesEmptyQuestion() { - LlamaConfig llamaConfig = new LlamaConfig(modelTag, stopWords, modelPath, llmPrefix, numThreads); - Llama llama = new Llama(); - llama.llmInit(llamaConfig); - - String question1 = "What is the capital of Morocco?"; - String response1 = llama.send(question1); - checkLlamaMatch(response1, "Rabat", true); - - // Send an empty prompt to simulate blank recordings or non-speech tokens being returned by speech recognition; - // then ask follow-up questions to ensure previous context persists when an empty prompt is injected in the conversation. - String emptyResponse = llama.send(""); // ToDo may revisit this to add an expected answer - - String question2 = "What languages do they speak there?"; - String response2 = llama.send(question2); - checkLlamaMatch(response2, "Arabic", true); - checkLlamaMatch(response2, "French", true); - - llama.freeModel(); - } - - @Test - public void testMangoSubtractionLongConversation() { - - LlamaConfig llamaConfig = new LlamaConfig(modelTag, stopWords, modelPath, llmPrefix, numThreads); - Llama llama = new Llama(); - llama.llmInit(llamaConfig); - - // 35 was determined to be upper limit for storing context but to avoid excessively long test runtime we cap at 20 - int originalMangoes = 5; - int mangoes = originalMangoes; - - // Set the initial ground truth in the conversation. - String initialContext = "There are " + originalMangoes + " mangoes."; - String initResponse = llama.send(initialContext); - String originalQuery = "How many mangoes were there originally?"; - String subtractQuery = "Subtract 1 mango."; - - // **Assert that the model acknowledges the initial count of mangoes.** - checkLlamaMatch(initResponse, String.valueOf(originalMangoes), true); - - // Loop to subtract 1 mango each iteration until reaching 0. 
- for (int i = 1; i < originalMangoes; i++) {
-
- // Query to subtract one mango
- String subtractionResponse = llama.send(subtractQuery);
- mangoes -= 1; // Update our expected count
- checkLlamaMatch(subtractionResponse, String.valueOf(mangoes), true);
-
- // Test if model still recalls the starting number
- if (i == originalMangoes - 1) {
- String response = llama.send(originalQuery);
- checkLlamaMatch(response, String.valueOf(originalMangoes), true);
- llama.resetContext();
- }
-
- }
-
- String postResetResponse = llama.send(originalQuery);
- checkLlamaMatch(postResetResponse, String.valueOf(originalMangoes), false);
- llama.freeModel();
- }
-
- @Test
- public void testInferenceRecoversAfterContextReset() {
- // Get model directory and config file path from system properties
- String modelDir = System.getProperty("model_dir");
- String configFilePath = System.getProperty("config_file");
- if (modelDir == null || configFilePath == null) {
- throw new RuntimeException("System properties for model_dir or config_file are not set!");
- }
-
- LlamaConfig llamaConfig = new LlamaConfig(modelTag, stopWords, modelPath, llmPrefix, numThreads);
- // Initialize Llama with the loaded config
- Llama llama = new Llama();
- llama.llmInit(llamaConfig);
-
- // First Question
- String question1 = "What is the capital of Morocco?";
- String response1 = llama.send(question1);
- checkLlamaMatch(response1, "Rabat", true);
- // Reset Context before second question
- llama.resetContext();
-
- // Second Question (After Reset)
- String question2 = "What languages do they speak there?";
- String response2 = llama.send(question2);
- checkLlamaMatch(response2, "Arabic", false);
- // Ask First Question Again. Note an additional reset is required to prevent the generic answer from previous question affecting new topic.
- llama.resetContext();
- String response3 = llama.send(question1);
-
- checkLlamaMatch(response3, "Rabat", true);
- String response4 = llama.send(question2);
- checkLlamaMatch(response4, "Arabic", true);
- checkLlamaMatch(response4, "French", true);
-
- // Free model after use
- llama.freeModel();
- }
-}
diff --git a/test/java/com/arm/LlmTestJNI.java b/test/java/com/arm/LlmTestJNI.java
new file mode 100644
index 0000000000000000000000000000000000000000..f1e5f4a0d8602efa871b1fdd0f20c9bb57fcd48c
--- /dev/null
+++ b/test/java/com/arm/LlmTestJNI.java
@@ -0,0 +1,264 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+package com.arm;
+
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assume.assumeTrue;
+
+import org.junit.Test;
+import org.junit.BeforeClass;
+
+import com.arm.Llm;
+import com.arm.LlmConfig;
+
+import java.io.*;
+import java.util.*;
+
+public class LlmTestJNI {
+ private static final String modelDir = System.getProperty("model_dir");
+ private static final String configFilePath = System.getProperty("config_file");
+ private static final String userConfigFilePath = System.getProperty("user_config_file");
+ private static final Map<String, String> variables = new HashMap<>();
+ private static int numThreads = 4;
+ private static String modelTag = "";
+ private static String userTag = "";
+ private static String modelPath = "";
+ private static String llmPrefix = "";
+ private static List<String> stopWords = new ArrayList<>();
+
+ /**
+ * Instead of matching LLM's response to expected response,
+ * check whether the response contains the salient parts of expected response.
+ * Pass true to check match and false to assert absence of salient parts for negative tests.
+ */
+ private static void checkLlmMatch(String response, String expectedResponse, boolean checkMatch) {
+ boolean matches = response.contains(expectedResponse);
+ if (checkMatch) {
+ assertTrue("Response mismatch: response={" + response + "} should contain={" + expectedResponse + "}", matches);
+ } else {
+ assertFalse("Response mismatch: response={" + response + "} shouldn't contain={" + expectedResponse + "}", matches);
+ }
+ }
+
+ /**
+ * Loads variables from the specified configuration file.
+ *
+ * @param filePath Path to the configuration file.
+ * @throws IOException If an I/O error occurs.
+ */
+ private static void loadVariables(String filePath) throws IOException {
+ try (BufferedReader br = new BufferedReader(new FileReader(filePath))) {
+ String line;
+ while ((line = br.readLine()) != null) {
+ if (!line.contains("=")) continue;
+ String[] parts = line.split("=", 2);
+ if (parts[0].trim().equals("stopWords")) {
+ stopWords.clear(); // Ensure no duplicates on reloading
+ stopWords.addAll(Arrays.asList(parts[1].split(",")));
+ } else {
+ variables.put(parts[0].trim(), parts[1].trim());
+ }
+ }
+ } catch (FileNotFoundException e) {
+ throw new IOException("Configuration file not found: " + filePath);
+ } catch (Exception e) {
+ throw new IOException("Error reading configuration file: " + e.getMessage());
+ }
+ }
+
+ @BeforeClass
+ public static void classSetup() throws IOException {
+ if (modelDir == null) throw new RuntimeException("System property 'model_dir' is not set!");
+ if (configFilePath == null)
+ throw new RuntimeException("System property 'config_file' is not set!");
+
+ loadVariables(configFilePath);
+ modelTag = variables.get("modelTag");
+ userTag = variables.getOrDefault("userTag","");
+ llmPrefix = variables.get("llmPrefix");
+ modelPath = modelDir + "/" + variables.get("llmModelName");
+ loadVariables(userConfigFilePath);
+ try {
+ numThreads = Integer.valueOf(variables.getOrDefault("numThreads","4"));
+ }
+ catch (NumberFormatException e) {
+ System.out.println("Invalid numThreads value in the user configuration file; defaulting to 4");
+ numThreads = 4;
+ }
+
+ }
+
+ @Test
+ public void testConfigLoading() {
+ LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag,numThreads);
+ assertTrue("Model tag is not empty", !llmConfig.getModelTag().isEmpty());
+ assertTrue("LLM prefix is not empty", !llmConfig.getLlmPrefix().isEmpty());
+ assertTrue("Stop words list is not empty", !llmConfig.getStopWords().isEmpty());
+ }
+
+ @Test
+ public void testLlmPrefixSetting() {
+ LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag);
+ Llm llm = new Llm();
+ llm.llmInit(llmConfig);
+
+ String newModelTag = "Ferdia";
+ String newPrefix = "Transcript of a dialog, where the User interacts with an AI Assistant named " + newModelTag +
+ ". " + newModelTag +
+ " is helpful, polite, honest, good at writing and answers honestly with a maximum of two sentences. " +
+ "User:";
+
+ llm.setLlmModelTag(newModelTag);
+ llm.setLlmPrefix(newPrefix);
+
+ String question = "What is your name?";
+ String response = llm.send(question);
+ checkLlmMatch(response, "Ferdia", true);
+ llm.freeModel();
+ }
+
+ @Test
+ public void testInferenceWithContextReset() {
+ LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag,numThreads);
+ Llm llm = new Llm();
+ llm.llmInit(llmConfig);
+
+ String question1 = "What is the capital of the country, Morocco?";
+ String response1 = llm.send(question1);
+ checkLlmMatch(response1, "Rabat", true);
+
+ // Resetting context should cause model to forget what country is being referred to
+ llm.resetContext();
+
+ String question2 = "What languages do they speak there?";
+ String response2 = llm.send(question2);
+ checkLlmMatch(response2, "Arabic", false);
+
+ llm.freeModel();
+ }
+
+ @Test
+ public void testInferenceWithoutContextReset() {
+ LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag,numThreads);
+ Llm llm = new Llm();
+ llm.llmInit(llmConfig);
+
+ String question1 = "What is the capital of the country, Morocco?";
+ String response1 = llm.send(question1);
+ checkLlmMatch(response1, "Rabat", true);
+
+ String question2 = "What languages do they speak there?";
+ String response2 = llm.send(question2);
+ checkLlmMatch(response2, "Arabic", true);
+
+ llm.freeModel();
+ }
+
+ @Test
+ public void testInferenceHandlesEmptyQuestion() {
+ LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag,numThreads);
+ Llm llm = new Llm();
+ llm.llmInit(llmConfig);
+
+ String question1 = "What is the capital of the country, Morocco?";
+ String response1 = llm.send(question1);
+ checkLlmMatch(response1, "Rabat", true);
+
+ // Send an empty prompt to simulate blank recordings or non-speech tokens being returned by speech recognition;
+ // then ask follow-up questions to ensure previous context persists when an empty prompt is injected in the conversation.
+ String emptyResponse = llm.send("");
+
+ checkLlmMatch(emptyResponse, "Rabat", true);
+
+ String question2 = "What languages do they speak there?";
+ String response2 = llm.send(question2);
+ checkLlmMatch(response2, "Arabic", true);
+
+ llm.freeModel();
+ }
+
+ @Test
+ public void testMangoSubtractionLongConversation() {
+
+ LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag,numThreads);
+ Llm llm = new Llm();
+ llm.llmInit(llmConfig);
+
+ // 35 was determined to be the upper limit for storing context, but to avoid an excessively long test runtime we use a much smaller count here
+ int originalMangoes = 5;
+ int mangoes = originalMangoes;
+
+ // Set the initial ground truth in the conversation.
+ String initialContext = "There are " + originalMangoes + " mangoes in a basket.";
+ String initResponse = llm.send(initialContext);
+ String originalQuery = "How many mangoes did we start with? ";
+ String subtractQuery = "Subtract 1 mango. How many mangoes are left now? ";
+
+ // Assert that the model acknowledges the initial count of mangoes.
+ checkLlmMatch(initResponse, String.valueOf(originalMangoes), true);
+
+ // Loop to subtract 1 mango each iteration until reaching 0.
+ for (int i = 1; i < originalMangoes; i++) { + + // Query to subtract one mango + String subtractionResponse = llm.send(subtractQuery); + mangoes -= 1; // Update our expected count + checkLlmMatch(subtractionResponse, String.valueOf(mangoes), true); + + // Test if model still recalls the starting number + if (i == originalMangoes - 1) { + String response = llm.send(originalQuery); + checkLlmMatch(response, String.valueOf(originalMangoes), true); + llm.resetContext(); + } + + } + + String postResetResponse = llm.send(originalQuery); + checkLlmMatch(postResetResponse, String.valueOf(originalMangoes), false); + llm.freeModel(); + } + + @Test + public void testInferenceRecoversAfterContextReset() { + // Get model directory and config file path from system properties + String modelDir = System.getProperty("model_dir"); + String configFilePath = System.getProperty("config_file"); + if (modelDir == null || configFilePath == null) { + throw new RuntimeException("System properties for model_dir or config_file are not set!"); + } + + LlmConfig llmConfig = new LlmConfig(modelTag, stopWords, modelPath, llmPrefix,userTag,numThreads); + Llm llm = new Llm(); + llm.llmInit(llmConfig); + + // First Question + String question1 = "What is the capital of the country, Morocco?"; + String response1 = llm.send(question1); + checkLlmMatch(response1, "Rabat", true); + // Reset Context before second question + llm.resetContext(); + + // Second Question (After Reset) + String question2 = "What languages do they speak there?"; + String response2 = llm.send(question2); + checkLlmMatch(response2, "Arabic", false); + // Ask First Question Again. Note an additional reset is required to prevent the generic answer + // from previous question affecting new topic. + llm.resetContext(); + String response3 = llm.send(question1); + + checkLlmMatch(response3, "Rabat", true); + String response4 = llm.send(question2); + + checkLlmMatch(response4, "Arabic", true); + checkLlmMatch(response4, "French", true); + + // Free model after use + llm.freeModel(); + } +}
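
For reference, the snippet below sketches the minimal call sequence the new LlmTestJNI tests exercise against the JNI wrapper (LlmConfig construction, llmInit, send, resetContext, freeModel). It is only an illustration, not part of the change set: the constructor argument order mirrors the calls in the tests above, the model path, tags, prefix, stop words and thread count are placeholder values rather than shipped defaults, and it assumes the native library is already available to the JVM.

    // Minimal usage sketch, based only on calls that appear in LlmTestJNI above.
    import com.arm.Llm;
    import com.arm.LlmConfig;

    import java.util.Arrays;
    import java.util.List;

    public class LlmQuickStart {
        public static void main(String[] args) {
            // Placeholder configuration values; in the tests these come from the config files.
            List<String> stopWords = Arrays.asList("User:", "Assistant:");
            LlmConfig config = new LlmConfig(
                    "Assistant:",                   // modelTag (placeholder)
                    stopWords,
                    "/path/to/llm-model-file",      // modelPath (placeholder)
                    "Transcript of a dialog ...",   // llmPrefix (placeholder)
                    "User:",                        // userTag (placeholder)
                    4);                             // numThreads

            Llm llm = new Llm();
            llm.llmInit(config);                    // load the model through JNI

            // Single-turn question, then clean up.
            String answer = llm.send("What is the capital of the country, Morocco?");
            System.out.println(answer);

            llm.resetContext();                     // drop the conversation history
            llm.freeModel();                        // release native resources
        }
    }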