From ee8f4da7332459f4f1226f7f4ca0c7c242e51520 Mon Sep 17 00:00:00 2001 From: Yunus Kalkan Date: Wed, 30 Apr 2025 10:19:53 +0100 Subject: [PATCH] MLECO-6009: [STT] Thread configuration * Introduce whisperConfig.java for whisper parameters * Add option to configure Whisper parameters via user config or defaults * Improve JNI methods Change-Id: I4a0fea3034ce7d71bc28d5f9587a33b82b0cba2c Signed-off-by: Yunus Kalkan --- src/cpp/include/STT.hpp | 32 ++- src/cpp/whisper_cpp/include/WhisperImpl.hpp | 77 ++++--- src/cpp/whisper_cpp/jni/Whisper.cpp | 37 +++- src/java/CMakeLists.txt | 2 +- src/java/com/arm/stt/Whisper.java | 45 +++- src/java/com/arm/stt/WhisperConfig.java | 226 ++++++++++++++++++++ test/cpp/WhisperTest.cpp | 17 +- test/java/com/arm/stt/WhisperTestApp.java | 27 ++- 8 files changed, 413 insertions(+), 50 deletions(-) create mode 100644 src/java/com/arm/stt/WhisperConfig.java diff --git a/src/cpp/include/STT.hpp b/src/cpp/include/STT.hpp index 66aa1c8..f3dcd9d 100644 --- a/src/cpp/include/STT.hpp +++ b/src/cpp/include/STT.hpp @@ -17,7 +17,30 @@ template class STT { private: + T stt; public: + /** + * Initializes the Whisper parameters with the specified settings. + * @param printRealTime whether to print partial decoding results in real-time + * @param printProgress whether to print progress information + * @param timeStamps whether to include timestamps in the transcription + * @param printSpecial whether to include special tokens (e.g., markers) in the output + * @param translate whether to translate the transcription to English + * @param language the language code for transcription (e.g., "en", "fr", etc.) + * @param numThreads the number of CPU threads to use for transcription + * @param offsetMs an initial time offset (in milliseconds) for the transcription + * @param noContext whether to disable reusing context between segments + * @param singleSegment whether to transcribe the entire audio in a single segment + */ + void InitParams(const bool printRealtime, const bool printProgress, const bool timeStamps, + const bool printSpecial, const bool translate, const char *language, + const int numThreads, const int offsetMs, const bool noContext, + const bool singleSegment) + { + stt.InitParams(printRealtime, printProgress, timeStamps, printSpecial, translate, + language, numThreads, offsetMs, noContext, singleSegment); + } + /** * Function to load the chosen STT model and to init the context * @tparam P stt context type @@ -27,7 +50,7 @@ public: template P *InitContext(const char *pathToModel) { - return ((T *) this)->InitContext(pathToModel); + return stt.InitContext(pathToModel); } /** @@ -38,21 +61,20 @@ public: template void FreeContext(P* contextPtr) { - ((T *) this)->FreeContext(contextPtr); + stt.FreeContext(contextPtr); } /** * The entire transcription inference loop * @tparam P stt context type * @param contextPtr stt context pointer - * @param numThreads number of threads to use * @param audioData audio data to transcribe * @param audioDataLength length of the Audio data supplied */ template - std::string FullTranscribe(P* contextPtr, int numThreads, float* audioData, int audioDataLength) + std::string FullTranscribe(P* contextPtr, float* audioData, int audioDataLength) { - return ((T *) this)->FullTranscribe(contextPtr, numThreads, audioData, audioDataLength); + return stt.FullTranscribe(contextPtr, audioData, audioDataLength); } }; #endif //STT_STT_HPP diff --git a/src/cpp/whisper_cpp/include/WhisperImpl.hpp b/src/cpp/whisper_cpp/include/WhisperImpl.hpp index 0fec156..1c97004 100644 --- a/src/cpp/whisper_cpp/include/WhisperImpl.hpp +++ b/src/cpp/whisper_cpp/include/WhisperImpl.hpp @@ -17,8 +17,12 @@ * @brief Whisper Implementation of our STT API * */ -class WhisperImpl : public STT { +class WhisperImpl { private: + + std::string strLang{"en"}; + struct whisper_full_params whisperParams{}; + /** * Function to retrieve the total number of text segments * @param contextPtr whisper_context pointer @@ -41,34 +45,41 @@ private: return text; } +public: + WhisperImpl() = default; + /** - * The transcription inference loop, inspired by an existing whisper.cpp example - * @param contextPtr whisper_context pointer - * @param numThreads number of threads to use - * @param audioDataPtr pointer to audio data to transcribe - * @param audioDataLength length of the audio data array - */ - void Transcribe(whisper_context* contextPtr, const int numThreads, const float* audioDataPtr, const int audioDataLength) + * Initializes the Whisper parameters with the specified settings. + * @param printRealTime whether to print partial decoding results in real-time + * @param printProgress whether to print progress information + * @param timeStamps whether to include timestamps in the transcription + * @param printSpecial whether to include special tokens (e.g., markers) in the output + * @param translate whether to translate the transcription to English + * @param language the language code for transcription (e.g., "en", "fr", etc.) + * @param numThreads the number of CPU threads to use for transcription + * @param offsetMs an initial time offset (in milliseconds) for the transcription + * @param noContext whether to disable reusing context between segments + * @param singleSegment whether to transcribe the entire audio in a single segment + */ + void InitParams(const bool printRealtime, const bool printProgress, const bool printTimestamps, + const bool printSpecial, const bool translate, const char *language, + const int numThreads, const int offsetMs, const bool noContext, + const bool singleSegment) { - whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - params.print_realtime = true; - params.print_progress = false; - params.print_timestamps = true; - params.print_special = false; - params.translate = false; - params.language = "en"; - params.n_threads = numThreads; - params.offset_ms = 0; - params.no_context = true; - params.single_segment = false; - - whisper_full(contextPtr, params, &audioDataPtr[0], audioDataLength); - whisper_reset_timings(contextPtr); - whisper_print_timings(contextPtr); - } + this->strLang = std::string(language); + this->whisperParams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + this->whisperParams.print_realtime = printRealtime; + this->whisperParams.print_progress = printProgress; + this->whisperParams.print_timestamps = printTimestamps; + this->whisperParams.print_special = printSpecial; + this->whisperParams.translate = translate; + this->whisperParams.language = strLang.c_str(); + this->whisperParams.n_threads = numThreads; + this->whisperParams.offset_ms = offsetMs; + this->whisperParams.no_context = noContext; + this->whisperParams.single_segment = singleSegment; -public: - WhisperImpl() = default; + } /** * Function to load the chosen STT model and to init the context @@ -92,15 +103,19 @@ public: /** * The full transcription inference loop, and retrieval of all text segments - * Taken from whisper.cpp/examples/whisper.android/lib/src/main/jni/whisper/jni.c and slightly modified + * Taken from whisper.cpp/examples/whisper.android/lib/src/main/jni/whisper/jni.c and slightly + * modified * @param contextPtr whisper_context pointer - * @param numThreads number of threads to use - * @param audioDataPtr pointer to audio data to transcribe + * @param audioDataPtr pointer to audio data to transcribe * @param audioDataLength length of the audio data array */ - std::string FullTranscribe(whisper_context* contextPtr, const int numThreads, const float* audioDataPtr, const int audioDataLength) + std::string FullTranscribe(whisper_context* contextPtr, const float* audioDataPtr, + const int audioDataLength) { - Transcribe(contextPtr, numThreads, audioDataPtr, audioDataLength); + + whisper_full(contextPtr, whisperParams, &audioDataPtr[0], audioDataLength); + whisper_reset_timings(contextPtr); + whisper_print_timings(contextPtr); int count = GetTextSegmentCount(contextPtr); diff --git a/src/cpp/whisper_cpp/jni/Whisper.cpp b/src/cpp/whisper_cpp/jni/Whisper.cpp index 7ef2bd0..484abcd 100644 --- a/src/cpp/whisper_cpp/jni/Whisper.cpp +++ b/src/cpp/whisper_cpp/jni/Whisper.cpp @@ -13,7 +13,36 @@ extern "C" { #endif // Instantiating a Whisper type STT implementation -STT stt; +static STT stt; + +/** + * Initialize whisper parameters + * + * @param env JNI environment + * @param jprintRealTime whether to print partial decoding results in real-time + * @param jprintProgress whether to print progress information + * @param jtimeStamps whether to include timestamps in the transcription + * @param jprintSpecial whether to include special tokens (e.g., markers) in the output + * @param jtranslate whether to translate the transcription to English + * @param jlanguage the language code for transcription (e.g., "en", "fr", etc.) + * @param jnumThreads the number of CPU threads to use for transcription + * @param joffsetMs an initial time offset (in milliseconds) for the transcription + * @param jnoContext whether to disable reusing context between segments + * @param jsingleSegment whether to transcribe the entire audio in a single segment + */ +JNIEXPORT void JNICALL +Java_com_arm_stt_Whisper_initParams(JNIEnv *env, jobject, jboolean jprintRealtime, + jboolean jprintProgress, jboolean jtimeStamps, + jboolean jprintSpecial, jboolean jtranslate, jstring jlanguage, + jint jnumThreads, jint joffsetMs, jboolean jnoContext, + jboolean jsingleSegment) +{ + const char *language_chars = env->GetStringUTFChars(jlanguage, nullptr); + stt.InitParams(jprintRealtime, jprintProgress, jtimeStamps, jprintSpecial, + jtranslate, language_chars, jnumThreads, joffsetMs, + jnoContext, jsingleSegment); + env->ReleaseStringUTFChars(jlanguage, language_chars); +} /** * Initialize whisper context @@ -47,19 +76,17 @@ JNIEXPORT void JNICALL Java_com_arm_stt_Whisper_freeContext * Full transcribe function to * @param env JNI environment * @param contextPtr pointer to whisper context - * @param numThreads number of threads to use * @param audioData audio data to transcribe * @return full transcription */ JNIEXPORT jstring JNICALL Java_com_arm_stt_Whisper_fullTranscribe - (JNIEnv *env, jobject, jlong contextPtr, - jint numThreads, jfloatArray audioData) + (JNIEnv *env, jobject, jlong contextPtr, jfloatArray audioData) { auto *context = reinterpret_cast(contextPtr); jfloat *audio_data_arr = env->GetFloatArrayElements(audioData, nullptr); const jsize audio_data_length = env->GetArrayLength(audioData); - const std::string transcribed = stt.FullTranscribe(context, numThreads, audio_data_arr, audio_data_length); + const std::string transcribed = stt.FullTranscribe(context, audio_data_arr, audio_data_length); env->ReleaseFloatArrayElements(audioData, audio_data_arr, JNI_ABORT); return env->NewStringUTF(transcribed.c_str()); diff --git a/src/java/CMakeLists.txt b/src/java/CMakeLists.txt index d5c367a..8f30505 100644 --- a/src/java/CMakeLists.txt +++ b/src/java/CMakeLists.txt @@ -16,7 +16,7 @@ project(arm-stt-java-prj add_library(arm-stt-java INTERFACE) if (${STT_DEP_NAME} STREQUAL "whisper.cpp") - target_sources(arm-stt-java INTERFACE com/arm/stt/Whisper.java) + target_sources(arm-stt-java INTERFACE com/arm/stt/Whisper.java com/arm/stt/WhisperConfig.java) add_dependencies(arm-stt-java arm-stt-jni) else() message(FATAL_ERROR "${STT_DEP_NAME} is currently not supported :(") diff --git a/src/java/com/arm/stt/Whisper.java b/src/java/com/arm/stt/Whisper.java index dc30f8b..3abfbd4 100644 --- a/src/java/com/arm/stt/Whisper.java +++ b/src/java/com/arm/stt/Whisper.java @@ -24,6 +24,48 @@ public class Whisper { */ public native long initContext(String modelPath); + /** + * Function to extracts parameters from WhisperConfig object and + * run the private InitParams function to initialize the parameters + * + * @param whisperConfig the configuration object containing Whisper parameter settings + */ + public void initParameters(WhisperConfig whisperConfig) + { + boolean printRealTime = whisperConfig.isPrintRealTime(); + boolean printProgress = whisperConfig.isPrintProgress(); + boolean timeStamps = whisperConfig.isTimeStamps(); + boolean printSpecial = whisperConfig.isPrintSpecial(); + boolean translate = whisperConfig.isTranslate(); + String language = whisperConfig.getLanguage(); + int numThreads = whisperConfig.getNumThreads(); + int offsetMs = whisperConfig.getOffsetMs(); + boolean noContext = whisperConfig.isNoContext(); + boolean singleSegment =whisperConfig.isSingleSegment(); + + initParams(printRealTime, printProgress, timeStamps, printSpecial, translate, language, + numThreads, offsetMs, noContext, singleSegment); + } + + /** + * Initializes the native Whisper parameters with the specified settings. + * + * @param printRealTime whether to print partial decoding results in real-time + * @param printProgress whether to print progress information + * @param timeStamps whether to include timestamps in the transcription + * @param printSpecial whether to include special tokens (e.g., markers) in the output + * @param translate whether to translate the transcription to English + * @param language the language code for transcription (e.g., "en", "fr", etc.) + * @param numThreads the number of CPU threads to use for transcription + * @param offsetMs an initial time offset (in milliseconds) for the transcription + * @param noContext whether to disable reusing context between segments + * @param singleSegment whether to transcribe the entire audio in a single segment + */ + private native void initParams(boolean printRealTime, boolean printProgress, boolean timeStamps, + boolean printSpecial, boolean translate, String language, + int numThreads, int offsetMs, boolean noContext, + boolean singleSegment); + /** * Function to free the previously initialised whisper_context * @@ -35,9 +77,8 @@ public class Whisper { * Function to run the entire transcription inference loop * * @param contextPtr pointer to the context object previously initialised - * @param numThreads number of threads to use * @param audioData audio data to transcribe * @return transcribed string object */ - public native String fullTranscribe(long contextPtr, int numThreads, float[] audioData); + public native String fullTranscribe(long contextPtr, float[] audioData); } diff --git a/src/java/com/arm/stt/WhisperConfig.java b/src/java/com/arm/stt/WhisperConfig.java new file mode 100644 index 0000000..a038f4c --- /dev/null +++ b/src/java/com/arm/stt/WhisperConfig.java @@ -0,0 +1,226 @@ + // + // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates + // + // SPDX-License-Identifier: Apache-2.0 + // + +package com.arm.stt; + +public class WhisperConfig { + + private boolean printRealTime; + private boolean printProgress; + private boolean timeStamps; + private boolean printSpecial; + private boolean translate; + private String language; + private Integer numThreads; + private Integer offsetMs; + private boolean noContext; + private boolean singleSegment; + + public WhisperConfig(boolean printRealTime, boolean printProgress, boolean timeStamps, + boolean printSpecial, boolean translate, String language, + int numThreads, int offsetMs, boolean noContext, boolean singleSegment) + { + this.printRealTime = printRealTime; + this.printProgress = printProgress; + this.timeStamps = timeStamps; + this.printSpecial = printSpecial; + this.translate = translate; + this.language = language; + this.numThreads = numThreads; + this.offsetMs = offsetMs; + this.noContext = noContext; + this.singleSegment = singleSegment; + + } + + public WhisperConfig() { + + } + + /** + * Gets the number of threads to use. + * + * @return the number of threads + */ + public Integer getNumThreads() { + return numThreads; + } + + /** + * Sets the number of threads to use + * + * @param numThreads the number of threads + */ + public void setNumThreads(Integer numThreads) { + this.numThreads = numThreads; + } + + /** + * Checks if real-time transcription printing is enabled + * + * @return true if enabled, false otherwise + */ + public boolean isPrintRealTime() { + return printRealTime; + } + + + /** + * Enables or disables real-time transcription printing + * + * @param printRealTime true to enable, false to disable + */ + public void setPrintRealTime(boolean printRealTime) { + this.printRealTime = printRealTime; + } + + + /** + * Checks if progress printing is enabled + * + * @return true if enabled, false otherwise + */ + public boolean isPrintProgress() { + return printProgress; + } + + /** + * Enables or disables progress printing + * + * @param printProgress true to enable, false to disable + */ + public void setPrintProgress(boolean printProgress) { + this.printProgress = printProgress; + } + + /** + * Checks if timestamps are included in the transcription + * + * @return true if included, false otherwise + */ + public boolean isTimeStamps() { + return timeStamps; + } + + + /** + * Enables or disables inclusion of timestamps + * + * @param timeStamps true to include, false to exclude + */ + public void setTimeStamps(boolean timeStamps) { + this.timeStamps = timeStamps; + } + + /** + * Checks if special characters are printed in the output + * + * @return true if printed, false otherwise + */ + public boolean isPrintSpecial() { + return printSpecial; + } + + + /** + * Enables or disables printing of special characters + * + * @param printSpecial true to enable, false to disable + */ + public void setPrintSpecial(boolean printSpecial) { + this.printSpecial = printSpecial; + } + + /** + * Checks if translation to English is enabled + * + * @return true if enabled, false otherwise + */ + public boolean isTranslate() { + return translate; + } + + /** + * Enables or disables translation to English + * + * @param translate true to enable, false to disable + */ + public void setTranslate(boolean translate) { + this.translate = translate; + } + + /** + * Gets the language code for transcription + * + * @return the language code + */ + public String getLanguage() { + return language; + } + + /** + * Sets the language code for transcription + * + * @param language the language code (e.g., "en", "es") + */ + public void setLanguage(String language) { + this.language = language; + } + + /** + * Gets the offset in milliseconds to start processing from + * + * @return the offset in milliseconds + */ + public Integer getOffsetMs() { + return offsetMs; + } + + /** + * Sets the offset in milliseconds to start processing from. + * + * @param offsetMs the offset in milliseconds + */ + public void setOffsetMs(Integer offsetMs) { + this.offsetMs = offsetMs; + } + + /** + * Checks if context from previous segments is disabled. + * + * @return true if context is disabled, false otherwise + */ + public boolean isNoContext() { + return noContext; + } + + /** + * Enables or disables context from previous segments. + * + * @param noContext true to disable, false to enable + */ + public void setNoContext(boolean noContext) { + this.noContext = noContext; + } + + /** + * Checks if only a single segment should be returned. + * + * @return true if only one segment is returned, false otherwise + */ + public boolean isSingleSegment() { + return singleSegment; + } + + /** + * Enables or disables returning only a single segment. + * + * @param singleSegment true to enable, false to disable + */ + public void setSingleSegment(boolean singleSegment) { + this.singleSegment = singleSegment; + } +} diff --git a/test/cpp/WhisperTest.cpp b/test/cpp/WhisperTest.cpp index 7d30a63..62f9b0a 100644 --- a/test/cpp/WhisperTest.cpp +++ b/test/cpp/WhisperTest.cpp @@ -39,8 +39,21 @@ TEST_CASE("Test audio file float representation to text") auto* context = stt.InitContext(modelPath.c_str()); std::vector audioData = ReadAudioData(audioDataPath); - constexpr int threads = 2; - const std::string transcribed = stt.FullTranscribe(context, threads, &audioData[0], audioData.size()); + const bool printRealtime = true; + const bool printProgress = false; + const bool timeStamps = true; + const bool printSpecial = false; + const bool translate = false; + const char *language = "en"; + const int numThreads = 2; + const int offsetMs = 0; + const bool noContext = true; + const bool singleSegment = false; + + stt.InitParams(printRealtime, printProgress, timeStamps, printSpecial, translate, language, + numThreads, offsetMs, noContext, singleSegment); + + const std::string transcribed = stt.FullTranscribe(context, &audioData[0], audioData.size()); CHECK(transcribed == " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country."); } diff --git a/test/java/com/arm/stt/WhisperTestApp.java b/test/java/com/arm/stt/WhisperTestApp.java index c582ad1..3a34504 100644 --- a/test/java/com/arm/stt/WhisperTestApp.java +++ b/test/java/com/arm/stt/WhisperTestApp.java @@ -13,8 +13,8 @@ import java.io.*; import java.util.*; /** - * The WhisperTestApp is used to run a simple speech to text execution using the native Java functions described - * in Whisper.java, using a known audio input. + * The WhisperTestApp is used to run a simple speech to text execution using the native Java + * functions described in Whisper.java, using a known audio input. */ public class WhisperTestApp { @@ -25,10 +25,29 @@ public class WhisperTestApp { String testDataPath = System.getProperty("test_data_dir"); long context = whisper.initContext(modelPath + "/model.bin"); float[] audioData = readCSV(testDataPath + "/audioData.csv"); - String transcribed = whisper.fullTranscribe(context, 2, audioData); + + boolean printRealtime=true; + boolean printProgress=false; + boolean timeStamps=true; + boolean printSpecial=false; + boolean translate=false; + String language="en"; + int numThreads=2; + int offsetMs=0; + boolean noContext=true; + boolean singleSegment=false; + + WhisperConfig whisperConfig = new WhisperConfig(printRealtime, printProgress, timeStamps, + printSpecial, translate, language, + numThreads, offsetMs, noContext, + singleSegment); + + whisper.initParameters(whisperConfig); + String transcribed = whisper.fullTranscribe(context, audioData); String expected = " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country."; - assertTrue("Expected: [" + expected + "] but was [" + transcribed + "]", transcribed.equals(expected)); + assertTrue("Expected: [" + expected + "] but was [" + transcribed + "]", + transcribed.equals(expected)); } /** -- GitLab