From 52b07b0ebcea619164eb79c80c5b8e617aff4d37 Mon Sep 17 00:00:00 2001 From: Liam Date: Tue, 29 Apr 2025 18:17:24 +0100 Subject: [PATCH] MLECO-6048: Async mode tests for llama.cpp Support code added to allow testing of subscriber and token emission in async mode Signed-off-by: Liam Change-Id: I23f0e3919a525538e8f31965eeba6ea44e99ea80 --- test/java/com/arm/LlamaTestJNI.java | 183 +++++++++++++++++++++++++++- 1 file changed, 180 insertions(+), 3 deletions(-) diff --git a/test/java/com/arm/LlamaTestJNI.java b/test/java/com/arm/LlamaTestJNI.java index 47ac741..c259d14 100644 --- a/test/java/com/arm/LlamaTestJNI.java +++ b/test/java/com/arm/LlamaTestJNI.java @@ -6,8 +6,7 @@ package com.arm; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.assertFalse; +import static org.junit.Assert.*; import static org.junit.Assume.assumeTrue; import org.junit.Test; @@ -18,6 +17,10 @@ import com.arm.llm.LlamaConfig; import java.io.*; import java.util.*; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Flow; +import java.util.concurrent.TimeUnit; + public class LlamaTestJNI { private static final String modelDir = System.getProperty("model_dir"); @@ -26,10 +29,13 @@ public class LlamaTestJNI { private static final String LLAMA_MODEL_NAME = "model.gguf"; private static final int numThreads = 4; + // Timeout for subscriber latch await in seconds + private static final long LATCH_TIMEOUT_SECONDS = 5; + private static String modelTag = ""; private static String modelPath = ""; private static String llmPrefix = ""; - private static List stopWords = new ArrayList(); + private static List stopWords = new ArrayList<>(); /** * Instead of matching the actual response to expected response, @@ -86,6 +92,177 @@ public class LlamaTestJNI { modelPath = modelDir + "/" + LLAMA_MODEL_NAME; } + /** + * A test implementation of Flow.Subscriber that collects tokens from asynchronous operations. 
+ * It accumulates received tokens in a list and uses a CountDownLatch to signal when the end-of-stream + token ("<eos>") is received, allowing tests to wait for completion. + */ + static class TestSubscriber implements Flow.Subscriber<String> { + + private Flow.Subscription subscription; + private final List<String> receivedTokens = new ArrayList<>(); + // Latch to signal when <eos> has been received. + private final CountDownLatch latch = new CountDownLatch(1); + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + // Request an unlimited number of tokens. + subscription.request(Long.MAX_VALUE); + } + @Override + public void onNext(String token) { + receivedTokens.add(token); + // If the token indicates end-of-stream, count down the latch. + if ("<eos>".equals(token)) { + latch.countDown(); + } + } + @Override + public void onError(Throwable throwable) { + // In case of error, count down the latch so the test can proceed. + latch.countDown(); + } + @Override + public void onComplete() { + latch.countDown(); + } + public List<String> getReceivedTokens() { + return receivedTokens; + } + public boolean await(long timeout, TimeUnit unit) throws InterruptedException { + return latch.await(timeout, unit); + } + } + + @Test + public void testAsyncPublishing() throws Exception { + // Create and initialize the Llama instance with test config using global variables. + Llama llama = new Llama(); + LlamaConfig llamaConfig = new LlamaConfig(modelTag, stopWords, modelPath, llmPrefix, numThreads); + llama.llmInit(llamaConfig); + // Set up our test subscriber. + TestSubscriber subscriber = new TestSubscriber(); + llama.setSubscriber(subscriber); + llama.sendAsync("what is 2 + 2"); + // Wait up to LATCH_TIMEOUT_SECONDS seconds for the subscriber to receive the <eos> token. + boolean completed = subscriber.await(LATCH_TIMEOUT_SECONDS, TimeUnit.SECONDS); + assertTrue("Subscriber received <eos> token within timeout", completed); + // Retrieve and check the received tokens.
+ List<String> tokens = subscriber.getReceivedTokens(); + assertEquals("Last token should be <eos>", "<eos>", tokens.get(tokens.size() - 1)); + StringBuilder tokenString = new StringBuilder(); + for (String token : tokens) { + tokenString.append(token); + } + // Check that the tokens contain the number "4". + assertTrue("Tokens should contain the number 4", tokenString.toString().contains("4")); + // Clean up the model resources. + llama.freeModel(); + } + + @Test + public void testAsyncInferenceWithoutContextReset() throws Exception { + // Create and initialize the Llama instance with global config. + LlamaConfig llamaConfig = new LlamaConfig(modelTag, stopWords, modelPath, llmPrefix, numThreads); + Llama llama = new Llama(); + llama.llmInit(llamaConfig); + TestSubscriber subscriber1 = new TestSubscriber(); + llama.setSubscriber(subscriber1); + String question1 = "What is the capital of Morocco?"; + llama.sendAsync(question1); + // Wait up to LATCH_TIMEOUT_SECONDS seconds for the subscriber to receive the <eos> token. + boolean completed1 = subscriber1.await(LATCH_TIMEOUT_SECONDS, TimeUnit.SECONDS); + assertTrue("Subscriber received <eos> token within timeout for question1", completed1); + // Retrieve and check the received tokens.
+ List<String> tokens1 = subscriber1.getReceivedTokens(); + StringBuilder tokenString1 = new StringBuilder(); + for (String token : tokens1) { + tokenString1.append(token); + } + String response1 = tokenString1.toString(); + checkLlamaMatch(response1, "Rabat", true); + TestSubscriber subscriber2 = new TestSubscriber(); + llama.setSubscriber(subscriber2); + String question2 = "What languages do they speak there?"; + llama.sendAsync(question2); + boolean completed2 = subscriber2.await(LATCH_TIMEOUT_SECONDS, TimeUnit.SECONDS); + assertTrue("Subscriber received <eos> token within timeout for question2", completed2); + List<String> tokens2 = subscriber2.getReceivedTokens(); + StringBuilder tokenString2 = new StringBuilder(); + for (String token : tokens2) { + tokenString2.append(token); + } + String response2 = tokenString2.toString(); + checkLlamaMatch(response2, "Arabic", true); + // Clean up the model resources. + llama.freeModel(); + } + + @Test + public void testAsyncInferenceRecoversAfterContextReset() throws Exception { + LlamaConfig llamaConfig = new LlamaConfig(modelTag, stopWords, modelPath, llmPrefix, numThreads); + Llama llama = new Llama(); + llama.llmInit(llamaConfig); + String question1 = "What is the capital of Morocco?"; + TestSubscriber subscriber1 = new TestSubscriber(); + llama.setSubscriber(subscriber1); + llama.sendAsync(question1); + boolean completed1 = subscriber1.await(LATCH_TIMEOUT_SECONDS, TimeUnit.SECONDS); + assertTrue("Subscriber should receive <eos> token for question1", completed1); + List<String> tokens1 = subscriber1.getReceivedTokens(); + StringBuilder tokenString1 = new StringBuilder(); + for (String token : tokens1) { + tokenString1.append(token); + } + String response1 = tokenString1.toString(); + checkLlamaMatch(response1, "Rabat", true); + // Reset context before the next question.
+ llama.resetContext(); + String question2 = "What languages do they speak there?"; + TestSubscriber subscriber2 = new TestSubscriber(); + llama.setSubscriber(subscriber2); + llama.sendAsync(question2); + boolean completed2 = subscriber2.await(LATCH_TIMEOUT_SECONDS, TimeUnit.SECONDS); + assertTrue("Subscriber should receive <eos> token for question2", completed2); + List<String> tokens2 = subscriber2.getReceivedTokens(); + StringBuilder tokenString2 = new StringBuilder(); + for (String token : tokens2) { + tokenString2.append(token); + } + String response2 = tokenString2.toString(); + checkLlamaMatch(response2, "Arabic", false); + llama.resetContext(); + TestSubscriber subscriber3 = new TestSubscriber(); + llama.setSubscriber(subscriber3); + llama.sendAsync(question1); + boolean completed3 = subscriber3.await(LATCH_TIMEOUT_SECONDS, TimeUnit.SECONDS); + assertTrue("Subscriber should receive <eos> token for question3", completed3); + List<String> tokens3 = subscriber3.getReceivedTokens(); + StringBuilder tokenString3 = new StringBuilder(); + for (String token : tokens3) { + tokenString3.append(token); + } + String response3 = tokenString3.toString(); + checkLlamaMatch(response3, "Rabat", true); + // Fourth Question: Ask second question again. + TestSubscriber subscriber4 = new TestSubscriber(); + llama.setSubscriber(subscriber4); + llama.sendAsync(question2); + boolean completed4 = subscriber4.await(LATCH_TIMEOUT_SECONDS, TimeUnit.SECONDS); + assertTrue("Subscriber should receive <eos> token for question4", completed4); + List<String> tokens4 = subscriber4.getReceivedTokens(); + StringBuilder tokenString4 = new StringBuilder(); + for (String token : tokens4) { + tokenString4.append(token); + } + String response4 = tokenString4.toString(); + checkLlamaMatch(response4, "Arabic", true); + checkLlamaMatch(response4, "French", true); + // Free model resources.
+ llama.freeModel(); + } + + @Test public void testConfigLoading() { LlamaConfig llamaConfig = new LlamaConfig(modelTag, stopWords, modelPath, llmPrefix, numThreads); -- GitLab