From a4098cc5cc34f2328421f63aee401a09e6afe009 Mon Sep 17 00:00:00 2001
From: Emil Ohlsson
Date: Thu, 28 Nov 2024 13:27:50 +0100
Subject: [PATCH] Align test terminology with kernel terminology

The test code mostly uses the descriptions `nt_nt` and `nt_t` to
indicate whether the LHS and RHS matrices are transposed. This doesn't
match the naming scheme used by the kernels, which is `kxn` (RHS
non-transposed) and `nxk` (RHS transposed). This change aligns the test
code with the kernel naming scheme.

Signed-off-by: Emil Ohlsson
---
 test/reference/matmul.cpp                            | 19 ++++++++-----------
 test/reference/matmul.hpp                            |  8 ++++----
 .../matmul_clamp_f32_bf16p_bf16p_test.cpp            |  8 ++++----
 ...matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp        |  4 ++--
 test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp |  8 ++++----
 ...atmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp        |  2 +-
 test/tests/matmul_test.cpp                           | 12 ++++++------
 7 files changed, 29 insertions(+), 32 deletions(-)

diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp
index 381cf1d5..96d8f10f 100644
--- a/test/reference/matmul.cpp
+++ b/test/reference/matmul.cpp
@@ -188,7 +188,7 @@ std::vector matmul(
 template <
     typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale,
     typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData>
-std::vector matmul_clamp_nt_t(
+std::vector matmul_clamp_nxk(
     size_t m, size_t n, size_t k,                                                                       //
     const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width,  //
     const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width,  //
@@ -241,22 +241,21 @@ std::vector matmul_clamp_nt_t(
     return dst;
 }
 
-template std::vector matmul_clamp_nt_t(
+template std::vector matmul_clamp_nxk(
     size_t m, size_t n, size_t k,                                                                       //
     const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width,  //
     const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width,  //
     const void* biases,                                                                                 //
     float min_value, float max_value);
 
-template std::vector
-matmul_clamp_nt_t(
+template std::vector matmul_clamp_nxk(
     size_t m, size_t n, size_t k,                                                                       //
     const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width,  //
     const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width,  //
     const void* biases,                                                                                 //
     float min_value, float max_value);
 
-template std::vector matmul_clamp_nt_t(
+template std::vector matmul_clamp_nxk(
     size_t m, size_t n, size_t k,                                                                       //
     const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width,  //
     const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width,  //
@@ -266,7 +265,7 @@ template std::vector matmul_clamp_nt_t
 template <
     typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale,
     typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData>
-std::vector matmul_clamp_nt_nt(
+std::vector matmul_clamp_kxn(
     size_t m, size_t n, size_t k,                                                                       //
     const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width,  //
     const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width,  //
@@ -319,23 +318,21 @@ std::vector matmul_clamp_nt_nt(
     return dst;
 }
 
-template std::vector matmul_clamp_nt_nt(
+template std::vector matmul_clamp_kxn(
     size_t m, size_t n, size_t k,                                                                       //
     const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width,  //
     const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width,  //
     const void* biases,                                                                                 //
     float min_value, float max_value);
 
-template std::vector
-matmul_clamp_nt_nt(
+template std::vector matmul_clamp_kxn(
     size_t m, size_t n, size_t k,                                                                       //
     const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width,  //
     const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width,  //
     const void* biases,                                                                                 //
     float min_value, float max_value);
 
-template std::vector
-matmul_clamp_nt_nt(
+template std::vector matmul_clamp_kxn(
     size_t m, size_t n, size_t k,                                                                       //
     const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width,  //
     const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width,  //
diff --git a/test/reference/matmul.hpp b/test/reference/matmul.hpp
index 88a0729f..0e7348f1 100644
--- a/test/reference/matmul.hpp
+++ b/test/reference/matmul.hpp
@@ -67,7 +67,7 @@ std::vector matmul(
 
 /// Matrix multiplication with quantized input and floating-point output.
 ///
-/// The LHS matrix is non-transposed and the RHS matrix is transposed.
+/// The RHS matrix is transposed.
 ///
 /// @tparam LhsData The data type of the LHS matrix.
 /// @tparam LhsScale The data type of the quantization scales of the LHS matrix.
@@ -98,7 +98,7 @@
 template <
     typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale,
     typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData>
-std::vector matmul_clamp_nt_t(
+std::vector matmul_clamp_nxk(
     size_t m, size_t n, size_t k,                                                                       //
     const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width,  //
     const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width,  //
@@ -107,7 +107,7 @@ std::vector matmul_clamp_nt_t(
 
 /// Matrix multiplication with quantized input and floating-point output.
 ///
-/// The LHS matrix is non-transposed and the RHS matrix is non-transposed.
+/// The RHS matrix is non-transposed.
 ///
 /// @tparam LhsData The data type of the LHS matrix.
 /// @tparam LhsScale The data type of the quantization scales of the LHS matrix.
@@ -138,7 +138,7 @@ std::vector matmul_clamp_nt_t(
 template <
     typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale,
     typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData>
-std::vector matmul_clamp_nt_nt(
+std::vector matmul_clamp_kxn(
     size_t m, size_t n, size_t k,                                                                       //
     const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width,  //
     const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width,  //
diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp
index 7d2693d4..5690de9c 100644
--- a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp
+++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp
@@ -43,7 +43,7 @@ namespace kai::test {
 namespace {
 const std::array gemm_methods = {
     MatMulMethod{
-        .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla",
+        .name = "matmul_kxn_f32_bf16p_bf16p_8x12_neon_mla",
 
         .m0 = 8,
         .n0 = 12,
@@ -89,7 +89,7 @@ const std::array gemm_methods = {
         .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla,
     },
     MatMulMethod{
-        .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_opt_bias",
+        .name = "matmul_kxn_f32_bf16p_bf16p_8x12_neon_mla_opt_bias",
 
         .m0 = 8,
         .n0 = 12,
@@ -137,7 +137,7 @@
 
 const std::array gemv_methods = {
     MatMulMethod{
-        .name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot",
+        .name = "matmul_kxn_f32_bf16p_bf16p_1x36_neon_dot",
 
         .m0 = 1,
         .n0 = 12,
@@ -183,7 +183,7 @@ const std::array gemv_methods = {
         .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot,
     },
     MatMulMethod{
-        .name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot_opt_bias",
+        .name = "matmul_kxn_f32_bf16p_bf16p_1x36_neon_dot_opt_bias",
 
         .m0 = 1,
         .n0 = 12,
diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp
index 0155d7d8..2fc02d50 100644
--- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp
+++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp
@@ -89,7 +89,7 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_nxk) {
     const auto [ref_rhs_qsi4, ref_rhs_scales] =
         quantize_symmetric_per_block_dynamic(ref_rhs.data(), N, K, bl);
 
-    const auto ref_dst = matmul_clamp_nt_t(
+    const auto ref_dst = matmul_clamp_nxk(
         M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(),
         ref_rhs_scales.data(), nullptr, bl, nullptr, std::numeric_limits::lowest(),
         std::numeric_limits::max());
@@ -186,7 +186,7 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_kxn) {
         ref_rhs_qsi4_transposed.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride,
         ref_rhs_qsi4_kxn_size_bytes);
 
-    const auto ref_dst = matmul_clamp_nt_nt(
+    const auto ref_dst = matmul_clamp_kxn(
         M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(),
         ref_rhs_scales.data(), nullptr, bl, nullptr, std::numeric_limits::lowest(),
         std::numeric_limits::max());
diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp
index b245e7f5..2c814bb6 100644
--- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp
+++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp
@@ -86,7 +86,7 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsi4cx) {
     const auto [ref_rhs_qsi4, ref_rhs_scales] =
         quantize_symmetric_per_block_dynamic(ref_rhs.data(), N, K, K);
 
-    const auto ref_dst = matmul_clamp_nt_t(
+    const auto ref_dst = matmul_clamp_nxk(
         M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(),
         ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits::lowest(),
         std::numeric_limits::max());
@@ -165,7 +165,7 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsu4cx) {
     const auto [ref_rhs_qsi4, ref_rhs_scales] =
         quantize_symmetric_per_block_dynamic(ref_rhs.data(), N, K, K);
 
-    const auto ref_dst = matmul_clamp_nt_t(
+    const auto ref_dst = matmul_clamp_nxk(
         M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(),
         ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits::lowest(),
         std::numeric_limits::max());
@@ -257,7 +257,7 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsi4cx) {
         ref_rhs_qsi4_transposed.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride,
         ref_rhs_qsi4_kxn_size_bytes);
 
-    const auto ref_dst = matmul_clamp_nt_nt(
+    const auto ref_dst = matmul_clamp_kxn(
         M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(),
         ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits::lowest(),
         std::numeric_limits::max());
@@ -346,7 +346,7 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsu4cx) {
         ref_rhs_qsi4_transposed.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride,
         ref_rhs_qsi4_kxn_size_bytes);
 
-    const auto ref_dst = matmul_clamp_nt_nt(
+    const auto ref_dst = matmul_clamp_kxn(
         M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(),
         ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits::lowest(),
         std::numeric_limits::max());
diff --git a/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp
index 2defbd45..349ac2de 100644
--- a/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp
+++ b/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp
@@ -81,7 +81,7 @@ TEST_P(MatMulTest_f32_qsi8d32p_qsi4c32p, EndToEnd) {
     const auto [ref_rhs_qsi4, ref_rhs_scales] =
         quantize_symmetric_per_block_dynamic(ref_rhs.data(), N, K, bl);
 
-    const auto ref_dst = matmul_clamp_nt_t(
+    const auto ref_dst = matmul_clamp_nxk(
         M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), nullptr, bl, ref_rhs_qsi4.data(),
         ref_rhs_scales.data(), nullptr, bl, nullptr, std::numeric_limits::lowest(),
         std::numeric_limits::max());
diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp
index 859bc733..f6d0805b 100644
--- a/test/tests/matmul_test.cpp
+++ b/test/tests/matmul_test.cpp
@@ -32,7 +32,7 @@
 #include "test/reference/fill.hpp"
 #include "test/reference/pack.hpp"
 
-// matmul_nt_nt_fp16_fp16_fp16_6x16_neon_mla
+// matmul_fp16_fp16_fp16_6x16_neon_mla
 #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h"
 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h"
 
@@ -41,7 +41,7 @@
 #include "kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.h"
 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.h"
 
-// matmul_nt_nt_fp32_fp32_fp32_2vlx2vl_sme2_mopa
+// matmul_fp32_fp32_fp32_2vlx2vl_sme2_mopa
 #include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h"
"kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" @@ -57,7 +57,7 @@ namespace kai::test { /// List of supported matrix multiplication methods. static const std::array matmul_methods = { MatMulMethod{ - .name = "matmul_nt_nt_fp16_fp16_fp16_6x16_neon_mla", + .name = "matmul_kxn_fp16_fp16_fp16_6x16_neon_mla", .m0 = 6, .n0 = 16, @@ -110,7 +110,7 @@ static const std::array matmul_methods = { }, MatMulMethod{ - .name = "matmul_nt_nt_f16_f16p_f16p_2vlx2vl_sme2_mopa", + .name = "matmul_kxn_f16_f16p_f16p_2vlx2vl_sme2_mopa", .m0 = 2 * get_sme_vector_length(), .n0 = 2 * get_sme_vector_length(), @@ -169,7 +169,7 @@ static const std::array matmul_methods = { }, MatMulMethod{ - .name = "matmul_nt_nt_fp32_fp32_fp32_6x8_neon_mla", + .name = "matmul_kxn_fp32_fp32_fp32_6x8_neon_mla", .m0 = 6, .n0 = 8, @@ -222,7 +222,7 @@ static const std::array matmul_methods = { }, MatMulMethod{ - .name = "matmul_nt_nt_fp32_fp32_fp32_2vlx2vl_sme2_mopa", + .name = "matmul_kxn_fp32_fp32_fp32_2vlx2vl_sme2_mopa", .m0 = 2 * get_sme_vector_length(), .n0 = 2 * get_sme_vector_length(), -- GitLab