diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp index 381cf1d5ebfd0b6575e9842ea11720cff86475eb..96d8f10fe0d7f20006b6955b1a9f18ef752bf66d 100644 --- a/test/reference/matmul.cpp +++ b/test/reference/matmul.cpp @@ -188,7 +188,7 @@ std::vector matmul( template < typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData> -std::vector matmul_clamp_nt_t( +std::vector matmul_clamp_nxk( size_t m, size_t n, size_t k, // const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // @@ -241,22 +241,21 @@ std::vector matmul_clamp_nt_t( return dst; } -template std::vector matmul_clamp_nt_t( +template std::vector matmul_clamp_nxk( size_t m, size_t n, size_t k, // const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // const void* biases, // float min_value, float max_value); -template std::vector -matmul_clamp_nt_t( +template std::vector matmul_clamp_nxk( size_t m, size_t n, size_t k, // const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // const void* biases, // float min_value, float max_value); -template std::vector matmul_clamp_nt_t( +template std::vector matmul_clamp_nxk( size_t m, size_t n, size_t k, // const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // @@ -266,7 +265,7 @@ template std::vector matmul_clamp_nt_t -std::vector matmul_clamp_nt_nt( +std::vector matmul_clamp_kxn( size_t m, size_t n, size_t k, // const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // @@ -319,23 +318,21 @@ std::vector matmul_clamp_nt_nt( return dst; } -template std::vector matmul_clamp_nt_nt( +template std::vector matmul_clamp_kxn( size_t m, size_t n, size_t k, // const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // const void* biases, // float min_value, float max_value); -template std::vector -matmul_clamp_nt_nt( +template std::vector matmul_clamp_kxn( size_t m, size_t n, size_t k, // const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // const void* biases, // float min_value, float max_value); -template std::vector -matmul_clamp_nt_nt( +template std::vector matmul_clamp_kxn( size_t m, size_t n, size_t k, // const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // diff --git a/test/reference/matmul.hpp b/test/reference/matmul.hpp index 88a0729ff662f67b25321034a6981818c13edfe7..0e7348f1ff0e6c3b9534c367e138b1e60e8245fb 100644 --- a/test/reference/matmul.hpp +++ b/test/reference/matmul.hpp @@ -67,7 +67,7 @@ std::vector matmul( /// Matrix multiplication with quantized input and floating-point output. /// -/// The LHS matrix is non-transposed and the RHS matrix is transposed. +/// The RHS matrix is transposed. /// /// @tparam LhsData The data type of the LHS matrix. /// @tparam LhsScale The data type of the quantization scales of the LHS matrix. @@ -98,7 +98,7 @@ std::vector matmul( template < typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData> -std::vector matmul_clamp_nt_t( +std::vector matmul_clamp_nxk( size_t m, size_t n, size_t k, // const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // @@ -107,7 +107,7 @@ std::vector matmul_clamp_nt_t( /// Matrix multiplication with quantized input and floating-point output. /// -/// The LHS matrix is non-transposed and the RHS matrix is non-transposed. +/// The RHS matrix is non-transposed. /// /// @tparam LhsData The data type of the LHS matrix. /// @tparam LhsScale The data type of the quantization scales of the LHS matrix. @@ -138,7 +138,7 @@ std::vector matmul_clamp_nt_t( template < typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData> -std::vector matmul_clamp_nt_nt( +std::vector matmul_clamp_kxn( size_t m, size_t n, size_t k, // const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp index 7d2693d4b85eed882ff1fa7d2faa6abe26016303..5690de9ce2ac60faf8e21c84d1c6354ae9f60e6f 100644 --- a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -43,7 +43,7 @@ namespace kai::test { namespace { const std::array gemm_methods = { MatMulMethod{ - .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla", + .name = "matmul_kxn_f32_bf16p_bf16p_8x12_neon_mla", .m0 = 8, .n0 = 12, @@ -89,7 +89,7 @@ const std::array gemm_methods = { .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, }, MatMulMethod{ - .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_opt_bias", + .name = "matmul_kxn_f32_bf16p_bf16p_8x12_neon_mla_opt_bias", .m0 = 8, .n0 = 12, @@ -137,7 +137,7 @@ const std::array gemm_methods = { const std::array gemv_methods = { MatMulMethod{ - .name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot", + .name = "matmul_kxn_f32_bf16p_bf16p_1x36_neon_dot", .m0 = 1, .n0 = 12, @@ -183,7 +183,7 @@ const std::array gemv_methods = { .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, }, MatMulMethod{ - .name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot_opt_bias", + .name = "matmul_kxn_f32_bf16p_bf16p_1x36_neon_dot_opt_bias", .m0 = 1, .n0 = 12, diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp index 0155d7d80701810f2c9ee523e21100cfd228eb41..2fc02d5025f3375645adc669322a51ac82f72dc1 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp @@ -89,7 +89,7 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_nxk) { const auto [ref_rhs_qsi4, ref_rhs_scales] = quantize_symmetric_per_block_dynamic(ref_rhs.data(), N, K, bl); - const auto ref_dst = matmul_clamp_nt_t( + const auto ref_dst = matmul_clamp_nxk( M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), ref_rhs_scales.data(), nullptr, bl, nullptr, std::numeric_limits::lowest(), std::numeric_limits::max()); @@ -186,7 +186,7 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_kxn) { ref_rhs_qsi4_transposed.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride, ref_rhs_qsi4_kxn_size_bytes); - const auto ref_dst = matmul_clamp_nt_nt( + const auto ref_dst = matmul_clamp_kxn( M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), ref_rhs_scales.data(), nullptr, bl, nullptr, std::numeric_limits::lowest(), std::numeric_limits::max()); diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp index b245e7f55174a0dec670d6d2daba6f15661470f4..2c814bb695ff300c6583711b703d44efdd24e83b 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp @@ -86,7 +86,7 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsi4cx) { const auto [ref_rhs_qsi4, ref_rhs_scales] = quantize_symmetric_per_block_dynamic(ref_rhs.data(), N, K, K); - const auto ref_dst = matmul_clamp_nt_t( + const auto ref_dst = matmul_clamp_nxk( M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits::lowest(), std::numeric_limits::max()); @@ -165,7 +165,7 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsu4cx) { const auto [ref_rhs_qsi4, ref_rhs_scales] = quantize_symmetric_per_block_dynamic(ref_rhs.data(), N, K, K); - const auto ref_dst = matmul_clamp_nt_t( + const auto ref_dst = matmul_clamp_nxk( M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits::lowest(), std::numeric_limits::max()); @@ -257,7 +257,7 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsi4cx) { ref_rhs_qsi4_transposed.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride, ref_rhs_qsi4_kxn_size_bytes); - const auto ref_dst = matmul_clamp_nt_nt( + const auto ref_dst = matmul_clamp_kxn( M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits::lowest(), std::numeric_limits::max()); @@ -346,7 +346,7 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsu4cx) { ref_rhs_qsi4_transposed.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride, ref_rhs_qsi4_kxn_size_bytes); - const auto ref_dst = matmul_clamp_nt_nt( + const auto ref_dst = matmul_clamp_kxn( M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits::lowest(), std::numeric_limits::max()); diff --git a/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp index 2defbd45110147b8c4fefc2e58c59e9afb9f578c..349ac2defc8d88d730b3a14e0f7c3a92feba95e6 100644 --- a/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp +++ b/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp @@ -81,7 +81,7 @@ TEST_P(MatMulTest_f32_qsi8d32p_qsi4c32p, EndToEnd) { const auto [ref_rhs_qsi4, ref_rhs_scales] = quantize_symmetric_per_block_dynamic(ref_rhs.data(), N, K, bl); - const auto ref_dst = matmul_clamp_nt_t( + const auto ref_dst = matmul_clamp_nxk( M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), nullptr, bl, ref_rhs_qsi4.data(), ref_rhs_scales.data(), nullptr, bl, nullptr, std::numeric_limits::lowest(), std::numeric_limits::max()); diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index 859bc733acc022a9c66e128ab4def11b33205962..f6d0805be975973fefde2225920392449c33bc52 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -32,7 +32,7 @@ #include "test/reference/fill.hpp" #include "test/reference/pack.hpp" -// matmul_nt_nt_fp16_fp16_fp16_6x16_neon_mla +// matmul_fp16_fp16_fp16_6x16_neon_mla #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h" @@ -41,7 +41,7 @@ #include "kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.h" -// matmul_nt_nt_fp32_fp32_fp32_2vlx2vl_sme2_mopa +// matmul_fp32_fp32_fp32_2vlx2vl_sme2_mopa #include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" #include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" @@ -57,7 +57,7 @@ namespace kai::test { /// List of supported matrix multiplication methods. static const std::array matmul_methods = { MatMulMethod{ - .name = "matmul_nt_nt_fp16_fp16_fp16_6x16_neon_mla", + .name = "matmul_kxn_fp16_fp16_fp16_6x16_neon_mla", .m0 = 6, .n0 = 16, @@ -110,7 +110,7 @@ static const std::array matmul_methods = { }, MatMulMethod{ - .name = "matmul_nt_nt_f16_f16p_f16p_2vlx2vl_sme2_mopa", + .name = "matmul_kxn_f16_f16p_f16p_2vlx2vl_sme2_mopa", .m0 = 2 * get_sme_vector_length(), .n0 = 2 * get_sme_vector_length(), @@ -169,7 +169,7 @@ static const std::array matmul_methods = { }, MatMulMethod{ - .name = "matmul_nt_nt_fp32_fp32_fp32_6x8_neon_mla", + .name = "matmul_kxn_fp32_fp32_fp32_6x8_neon_mla", .m0 = 6, .n0 = 8, @@ -222,7 +222,7 @@ static const std::array matmul_methods = { }, MatMulMethod{ - .name = "matmul_nt_nt_fp32_fp32_fp32_2vlx2vl_sme2_mopa", + .name = "matmul_kxn_fp32_fp32_fp32_2vlx2vl_sme2_mopa", .m0 = 2 * get_sme_vector_length(), .n0 = 2 * get_sme_vector_length(),