diff --git a/test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp index 548484a7934b5a8a0b18bb10f9ee4ffe34672d1c..09a4ee1e9d7898385abfea26c6cee1ce7751cfd9 100644 --- a/test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp @@ -39,83 +39,82 @@ namespace kai::test { /// List of supported matrix multiplication methods. namespace { -std::array matmul_methods{}; - -struct MatMulMethodInitializer { - MatMulMethodInitializer() { - matmul_methods[0].name = "matmul_nt_nt_f16_bf16p_bf16p_8x12_neon_mla"; - matmul_methods[0].m0 = 8; - matmul_methods[0].n0 = 12; - matmul_methods[0].k0 = 4; - matmul_methods[0].dst_format = DataFormat(DataType::FP16); - matmul_methods[0].lhs_format = DataFormat(DataType::FP16); - matmul_methods[0].packed_lhs_format = - DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); - matmul_methods[0].rhs_format = DataFormat(DataType::FP16); - matmul_methods[0].packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 12, 4); - matmul_methods[0].bias_format = DataFormat(DataType::FP16); - matmul_methods[0].fn_is_supported = cpu_has_bf16; - matmul_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; - matmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; - matmul_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; - matmul_methods[0].fn_get_packed_lhs_offset = - kai_get_lhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[0].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; - matmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; - matmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; - matmul_methods[0].fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; - matmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; - matmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[0].fn_matmul_f16_bf16p_bf16p = kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - - matmul_methods[1].name = "matmul_nt_nt_f16_bf16p_bf16p_8x12_neon_mla_opt_bias"; - matmul_methods[1].m0 = 8; - matmul_methods[1].n0 = 12; - matmul_methods[1].k0 = 4; - matmul_methods[1].dst_format = DataFormat(DataType::FP16); - matmul_methods[1].lhs_format = DataFormat(DataType::FP16); - matmul_methods[1].packed_lhs_format = - DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); - matmul_methods[1].rhs_format = DataFormat(DataType::FP16); - matmul_methods[1].packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 12, 4); - matmul_methods[1].bias_format = DataFormat(DataType::UNKNOWN); - matmul_methods[1].fn_is_supported = cpu_has_bf16; - matmul_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; - matmul_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; - matmul_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; - matmul_methods[1].fn_get_packed_lhs_offset = - kai_get_lhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[1].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; - matmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; - matmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; - matmul_methods[1].fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; - matmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; - matmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - matmul_methods[1].fn_matmul_f16_bf16p_bf16p = kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - }; -}; -MatMulMethodInitializer init{}; +static const std::array& get_matmul_methods() { + static std::array matmul_methods{}; + + matmul_methods[0].name = "matmul_nt_nt_f16_bf16p_bf16p_8x12_neon_mla"; + matmul_methods[0].m0 = 8; + matmul_methods[0].n0 = 12; + matmul_methods[0].k0 = 4; + matmul_methods[0].dst_format = DataFormat(DataType::FP16); + matmul_methods[0].lhs_format = DataFormat(DataType::FP16); + matmul_methods[0].packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); + matmul_methods[0].rhs_format = DataFormat(DataType::FP16); + matmul_methods[0].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 12, 4); + matmul_methods[0].bias_format = DataFormat(DataType::FP16); + matmul_methods[0].fn_is_supported = cpu_has_bf16; + matmul_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; + matmul_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; + matmul_methods[0].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; + matmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[0].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_matmul_f16_bf16p_bf16p = kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + + matmul_methods[1].name = "matmul_nt_nt_f16_bf16p_bf16p_8x12_neon_mla_opt_bias"; + matmul_methods[1].m0 = 8; + matmul_methods[1].n0 = 12; + matmul_methods[1].k0 = 4; + matmul_methods[1].dst_format = DataFormat(DataType::FP16); + matmul_methods[1].lhs_format = DataFormat(DataType::FP16); + matmul_methods[1].packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); + matmul_methods[1].rhs_format = DataFormat(DataType::FP16); + matmul_methods[1].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 12, 4); + matmul_methods[1].bias_format = DataFormat(DataType::UNKNOWN); + matmul_methods[1].fn_is_supported = cpu_has_bf16; + matmul_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; + matmul_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; + matmul_methods[1].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; + matmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[1].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_matmul_f16_bf16p_bf16p = kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + + return matmul_methods; +} } // namespace @@ -327,7 +326,7 @@ TEST_P(MatMulTestBf16OutFp16, Output) { INSTANTIATE_TEST_SUITE_P( MatMul, MatMulTestBf16OutFp16, testing::Combine( - testing::ValuesIn(matmul_methods), + testing::ValuesIn(get_matmul_methods()), testing::Values( MatMulShape{3, 7, 3}, // Smaller than block size MatMulShape{12, 8, 4}, // Same block size diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp index a17399ef285258e534967409ed8e31a3c04c4045..be38ca16cedb5a80ff0e83ba98fc6d4fa0e661f6 100644 --- a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -50,265 +50,268 @@ namespace kai::test { /// List of supported matrix multiplication methods. namespace { -std::array gemm_methods{}; -std::array gemv_methods{}; - -struct MatMulMethodInitializer { - MatMulMethodInitializer() { - gemm_methods[0].name = "matmul_nt_nt_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa"; - gemm_methods[0].m0 = 2 * get_sme_vector_length(); - gemm_methods[0].n0 = 2 * get_sme_vector_length(); - gemm_methods[0].k0 = 2; - gemm_methods[0].dst_format = DataFormat(DataType::FP32); - gemm_methods[0].lhs_format = DataFormat(DataType::FP32); - gemm_methods[0].packed_lhs_format = DataFormat( - DataType::BF16, 2 * get_sme_vector_length(), 2, DataFormat::PackFormat::NONE, DataType::FP32, - DataType::UNKNOWN, 2 * get_sme_vector_length(), 2); - gemm_methods[0].rhs_format = DataFormat(DataType::FP32); - gemm_methods[0].packed_rhs_format = DataFormat( - DataType::BF16, 2 * get_sme_vector_length(), 2, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, - DataType::UNKNOWN, 2 * get_sme_vector_length(), 2); - gemm_methods[0].bias_format = DataFormat(DataType::FP32); - gemm_methods[0].fn_is_supported = cpu_has_sme2; - gemm_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; - gemm_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; - gemm_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; - gemm_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; - gemm_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; - gemm_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; - gemm_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; - gemm_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme; - gemm_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme; - gemm_methods[0].fn_get_packed_lhs_offset = - kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; - gemm_methods[0].fn_pack_lhs = kai_run_lhs_pack_bf16p2vlx2_f32_sme; - gemm_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; - gemm_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; - gemm_methods[0].fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; - gemm_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; - gemm_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; - gemm_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; - gemm_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; - gemm_methods[0].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; - - gemm_methods[1].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla"; - gemm_methods[1].m0 = 8; - gemm_methods[1].n0 = 12; - gemm_methods[1].k0 = 4; - gemm_methods[1].dst_format = DataFormat(DataType::FP32); - gemm_methods[1].lhs_format = DataFormat(DataType::FP32); - gemm_methods[1].packed_lhs_format = - DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4); - gemm_methods[1].rhs_format = DataFormat(DataType::FP32); - gemm_methods[1].packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); - gemm_methods[1].bias_format = DataFormat(DataType::FP32); - gemm_methods[1].fn_is_supported = cpu_has_bf16; - gemm_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemm_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon; - gemm_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon; - gemm_methods[1].fn_get_packed_lhs_offset = - kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[1].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon; - gemm_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemm_methods[1].fn_get_packed_rhs_size_generic_block_size = - kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemm_methods[1].fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[1].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemm_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemm_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[1].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - - gemm_methods[2].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output"; - gemm_methods[2].m0 = 8; - gemm_methods[2].n0 = 12; - gemm_methods[2].k0 = 4; - gemm_methods[2].dst_format = DataFormat(DataType::FP32); - gemm_methods[2].lhs_format = DataFormat(DataType::FP16); - gemm_methods[2].packed_lhs_format = - DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); - gemm_methods[2].rhs_format = DataFormat(DataType::FP16); - gemm_methods[2].packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); - gemm_methods[2].bias_format = DataFormat(DataType::FP32); - gemm_methods[2].fn_is_supported = cpu_has_bf16; - gemm_methods[2].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; - gemm_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; - gemm_methods[2].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; - gemm_methods[2].fn_get_packed_lhs_offset = - kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[2].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; - gemm_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; - gemm_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; - gemm_methods[2].fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; - gemm_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; - gemm_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[2].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - - gemm_methods[3].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output_opt_bias"; - gemm_methods[3].m0 = 8; - gemm_methods[3].n0 = 12; - gemm_methods[3].k0 = 4; - gemm_methods[3].dst_format = DataFormat(DataType::FP32); - gemm_methods[3].lhs_format = DataFormat(DataType::FP16); - gemm_methods[3].packed_lhs_format = - DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); - gemm_methods[3].rhs_format = DataFormat(DataType::FP16); - gemm_methods[3].packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); - gemm_methods[3].bias_format = DataFormat(DataType::UNKNOWN); - gemm_methods[3].fn_is_supported = cpu_has_bf16; - gemm_methods[3].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; - gemm_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; - gemm_methods[3].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; - gemm_methods[3].fn_get_packed_lhs_offset = - kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[3].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; - gemm_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; - gemm_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; - gemm_methods[3].fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; - gemm_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; - gemm_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[3].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - - gemm_methods[4].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_opt_bias"; - gemm_methods[4].m0 = 8; - gemm_methods[4].n0 = 12; - gemm_methods[4].k0 = 4; - gemm_methods[4].dst_format = DataFormat(DataType::FP32); - gemm_methods[4].lhs_format = DataFormat(DataType::FP32); - gemm_methods[4].packed_lhs_format = - DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4); - gemm_methods[4].rhs_format = DataFormat(DataType::FP32); - gemm_methods[4].packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); - gemm_methods[4].bias_format = DataFormat(DataType::UNKNOWN); - gemm_methods[4].fn_is_supported = cpu_has_bf16; - gemm_methods[4].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[4].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[4].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[4].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[4].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[4].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemm_methods[4].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[4].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon; - gemm_methods[4].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon; - gemm_methods[4].fn_get_packed_lhs_offset = - kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[4].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon; - gemm_methods[4].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemm_methods[4].fn_get_packed_rhs_size_generic_block_size = - kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemm_methods[4].fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[4].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemm_methods[4].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemm_methods[4].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[4].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - gemm_methods[4].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; - - gemv_methods[0].name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot"; - gemv_methods[0].m0 = 1; - gemv_methods[0].n0 = 12; - gemv_methods[0].k0 = 4; - gemv_methods[0].dst_format = DataFormat(DataType::FP32); - gemv_methods[0].lhs_format = DataFormat(DataType::FP32); - gemv_methods[0].packed_lhs_format = - DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4); - gemv_methods[0].rhs_format = DataFormat(DataType::FP32); - gemv_methods[0].packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); - gemv_methods[0].bias_format = DataFormat(DataType::FP32); - gemv_methods[0].fn_is_supported = cpu_has_bf16; - gemv_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemv_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon; - gemv_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon; - gemv_methods[0].fn_get_packed_lhs_offset = - kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[0].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon; - gemv_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemv_methods[0].fn_get_packed_rhs_size_generic_block_size = - kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemv_methods[0].fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[0].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemv_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemv_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[0].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - - gemv_methods[1].name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot_opt_bias"; - gemv_methods[1].m0 = 1; - gemv_methods[1].n0 = 12; - gemv_methods[1].k0 = 4; - gemv_methods[1].dst_format = DataFormat(DataType::FP32); - gemv_methods[1].lhs_format = DataFormat(DataType::FP32); - gemv_methods[1].packed_lhs_format = - DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4); - gemv_methods[1].rhs_format = DataFormat(DataType::FP32); - gemv_methods[1].packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); - gemv_methods[1].bias_format = DataFormat(DataType::UNKNOWN); - gemv_methods[1].fn_is_supported = cpu_has_bf16; - gemv_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemv_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon; - gemv_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon; - gemv_methods[1].fn_get_packed_lhs_offset = - kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[1].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon; - gemv_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemv_methods[1].fn_get_packed_rhs_size_generic_block_size = - kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemv_methods[1].fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[1].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemv_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemv_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - gemv_methods[1].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; - }; -}; -MatMulMethodInitializer init{}; +static const std::array& get_gemm_methods() { + static std::array gemm_methods{}; + gemm_methods[0].name = "matmul_nt_nt_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa"; + gemm_methods[0].m0 = 2 * get_sme_vector_length(); + gemm_methods[0].n0 = 2 * get_sme_vector_length(); + gemm_methods[0].k0 = 2; + gemm_methods[0].dst_format = DataFormat(DataType::FP32); + gemm_methods[0].lhs_format = DataFormat(DataType::FP32); + gemm_methods[0].packed_lhs_format = DataFormat( + DataType::BF16, 2 * get_sme_vector_length(), 2, DataFormat::PackFormat::NONE, DataType::FP32, + DataType::UNKNOWN, 2 * get_sme_vector_length(), 2); + gemm_methods[0].rhs_format = DataFormat(DataType::FP32); + gemm_methods[0].packed_rhs_format = DataFormat( + DataType::BF16, 2 * get_sme_vector_length(), 2, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, + DataType::UNKNOWN, 2 * get_sme_vector_length(), 2); + gemm_methods[0].bias_format = DataFormat(DataType::FP32); + gemm_methods[0].fn_is_supported = cpu_has_sme2; + gemm_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; + gemm_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme; + gemm_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme; + gemm_methods[0].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_pack_lhs = kai_run_lhs_pack_bf16p2vlx2_f32_sme; + gemm_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; + gemm_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; + gemm_methods[0].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; + gemm_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; + gemm_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + + gemm_methods[1].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla"; + gemm_methods[1].m0 = 8; + gemm_methods[1].n0 = 12; + gemm_methods[1].k0 = 4; + gemm_methods[1].dst_format = DataFormat(DataType::FP32); + gemm_methods[1].lhs_format = DataFormat(DataType::FP32); + gemm_methods[1].packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4); + gemm_methods[1].rhs_format = DataFormat(DataType::FP32); + gemm_methods[1].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); + gemm_methods[1].bias_format = DataFormat(DataType::FP32); + gemm_methods[1].fn_is_supported = cpu_has_bf16; + gemm_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon; + gemm_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon; + gemm_methods[1].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon; + gemm_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[1].fn_get_packed_rhs_size_generic_block_size = + kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[1].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + + gemm_methods[2].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output"; + gemm_methods[2].m0 = 8; + gemm_methods[2].n0 = 12; + gemm_methods[2].k0 = 4; + gemm_methods[2].dst_format = DataFormat(DataType::FP32); + gemm_methods[2].lhs_format = DataFormat(DataType::FP16); + gemm_methods[2].packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); + gemm_methods[2].rhs_format = DataFormat(DataType::FP16); + gemm_methods[2].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); + gemm_methods[2].bias_format = DataFormat(DataType::FP32); + gemm_methods[2].fn_is_supported = cpu_has_bf16; + gemm_methods[2].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; + gemm_methods[2].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; + gemm_methods[2].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; + gemm_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[2].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + + gemm_methods[3].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output_opt_bias"; + gemm_methods[3].m0 = 8; + gemm_methods[3].n0 = 12; + gemm_methods[3].k0 = 4; + gemm_methods[3].dst_format = DataFormat(DataType::FP32); + gemm_methods[3].lhs_format = DataFormat(DataType::FP16); + gemm_methods[3].packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); + gemm_methods[3].rhs_format = DataFormat(DataType::FP16); + gemm_methods[3].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); + gemm_methods[3].bias_format = DataFormat(DataType::UNKNOWN); + gemm_methods[3].fn_is_supported = cpu_has_bf16; + gemm_methods[3].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; + gemm_methods[3].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; + gemm_methods[3].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; + gemm_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[3].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + + gemm_methods[4].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_opt_bias"; + gemm_methods[4].m0 = 8; + gemm_methods[4].n0 = 12; + gemm_methods[4].k0 = 4; + gemm_methods[4].dst_format = DataFormat(DataType::FP32); + gemm_methods[4].lhs_format = DataFormat(DataType::FP32); + gemm_methods[4].packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4); + gemm_methods[4].rhs_format = DataFormat(DataType::FP32); + gemm_methods[4].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); + gemm_methods[4].bias_format = DataFormat(DataType::UNKNOWN); + gemm_methods[4].fn_is_supported = cpu_has_bf16; + gemm_methods[4].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[4].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon; + gemm_methods[4].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon; + gemm_methods[4].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon; + gemm_methods[4].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[4].fn_get_packed_rhs_size_generic_block_size = + kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[4].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[4].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[4].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + + return gemm_methods; +} + +static const std::array& get_gemv_methods() { + static std::array gemv_methods{}; + gemv_methods[0].name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot"; + gemv_methods[0].m0 = 1; + gemv_methods[0].n0 = 12; + gemv_methods[0].k0 = 4; + gemv_methods[0].dst_format = DataFormat(DataType::FP32); + gemv_methods[0].lhs_format = DataFormat(DataType::FP32); + gemv_methods[0].packed_lhs_format = + DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4); + gemv_methods[0].rhs_format = DataFormat(DataType::FP32); + gemv_methods[0].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); + gemv_methods[0].bias_format = DataFormat(DataType::FP32); + gemv_methods[0].fn_is_supported = cpu_has_bf16; + gemv_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon; + gemv_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon; + gemv_methods[0].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon; + gemv_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[0].fn_get_packed_rhs_size_generic_block_size = + kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[0].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + + gemv_methods[1].name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot_opt_bias"; + gemv_methods[1].m0 = 1; + gemv_methods[1].n0 = 12; + gemv_methods[1].k0 = 4; + gemv_methods[1].dst_format = DataFormat(DataType::FP32); + gemv_methods[1].lhs_format = DataFormat(DataType::FP32); + gemv_methods[1].packed_lhs_format = + DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4); + gemv_methods[1].rhs_format = DataFormat(DataType::FP32); + gemv_methods[1].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); + gemv_methods[1].bias_format = DataFormat(DataType::UNKNOWN); + gemv_methods[1].fn_is_supported = cpu_has_bf16; + gemv_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon; + gemv_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon; + gemv_methods[1].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon; + gemv_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[1].fn_get_packed_rhs_size_generic_block_size = + kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[1].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + + return gemv_methods; +} + } // namespace /// Matrix multiplication test fixture. @@ -535,7 +538,7 @@ TEST_P(MatMulTestBf16, Output) { INSTANTIATE_TEST_SUITE_P( MatMulGemm, MatMulTestBf16, testing::Combine( - testing::ValuesIn(gemm_methods), + testing::ValuesIn(get_gemm_methods()), testing::Values( MatMulShape{1, 1, 1}, // Smallest Possible Shape MatMulShape{3, 7, 3}, // Smaller than block size @@ -560,7 +563,7 @@ INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P( MatMulGemv, MatMulTestBf16, testing::Combine( - testing::ValuesIn(gemv_methods), + testing::ValuesIn(get_gemv_methods()), testing::Values( MatMulShape{1, 1, 1}, // Smallest Possible Shape MatMulShape{1, 1, 1023}, // Long K