diff --git a/CMakeLists.txt b/CMakeLists.txt index 404a54d3b8090a4dba034749c05f0e389ceb8212..f7853fd3219f54c825199fa2166cabf9fa2975b5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -88,57 +88,57 @@ set(KLEIDIAI_FILES_SCALAR kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.c - kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c - kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c + kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c + kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c ) set(KLEIDIAI_FILES_NEON_FP16 - kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c ) set(KLEIDIAI_FILES_NEON_BF16 + kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot.c + kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p1x4_f32_neon.c kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p8x4_f32_neon.c kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon.c - kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot.c - kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c ) set(KLEIDIAI_FILES_NEON_FP16_BF16 + kai/ukernels/matmul/matmul_clamp_f16_bf16p_bf16p/kai_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c kai/ukernels/matmul/pack/kai_lhs_pack_bf16p8x4_f16_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf16_f16_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf32_f16_neon.c - kai/ukernels/matmul/matmul_clamp_f16_bf16p_bf16p/kai_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c ) set(KLEIDIAI_FILES_NEON - kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c - kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.c - kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla_asm.S - kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.c + kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c + kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c ) set(KLEIDIAI_FILES_NEON_DOTPROD_ASM + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8_neon_dotprod_asm.S - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod_asm.S - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod_asm.S - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8_neon_dotprod_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod.c - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod_asm.S ) set(KLEIDIAI_FILES_NEON_DOTPROD @@ -146,25 +146,25 @@ set(KLEIDIAI_FILES_NEON_DOTPROD kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod.c - kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c - kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c - kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c ) set(KLEIDIAI_FILES_NEON_I8MM_ASM - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S ) set(KLEIDIAI_FILES_NEON_I8MM @@ -173,36 +173,36 @@ set(KLEIDIAI_FILES_NEON_I8MM kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c - kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c - kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c ) set(KLEIDIAI_FILES_SME kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c + kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.c - kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c - kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c - kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c + kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.c ) set(KLEIDIAI_FILES_SME2 + kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c - kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c - kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c - kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c - kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c - kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c + kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c + kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c ) @@ -263,35 +263,35 @@ if(KLEIDIAI_BUILD_TESTS) include(GoogleTest) add_library(kleidiai_test_framework - test/common/data_type.cpp + test/common/bfloat16.cpp + test/common/bfloat16_asm.S + test/common/compare.cpp + test/common/cpu_info.cpp test/common/data_format.cpp - test/common/printer.cpp + test/common/data_type.cpp + test/common/float16.cpp + test/common/float16_asm.S test/common/int4.cpp - test/common/compare.cpp test/common/matmul_test_common.cpp test/common/matrix_portion.cpp + test/common/printer.cpp test/common/rect.cpp test/common/round.cpp test/common/round_asm.S - test/common/bfloat16.cpp - test/common/bfloat16_asm.S - test/common/float16.cpp - test/common/float16_asm.S - test/common/cpu_info.cpp $<$>:test/common/sme.cpp> test/reference/binary_elementwise.cpp + test/reference/cast.cpp + test/reference/clamp.cpp + test/reference/fill.cpp test/reference/matmul.cpp test/reference/matmul_pack.cpp - test/reference/fill.cpp test/reference/pack.cpp test/reference/pad.cpp - test/reference/clamp.cpp test/reference/quantize.cpp test/reference/reduce.cpp - test/reference/transpose.cpp - test/reference/cast.cpp test/reference/reorder.cpp + test/reference/transpose.cpp ) target_compile_options(kleidiai_test_framework @@ -321,15 +321,15 @@ if(KLEIDIAI_BUILD_TESTS) add_executable(kleidiai_test test/tests/bfloat16_test.cpp test/tests/float16_test.cpp - test/tests/matmul_test.cpp + test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp + test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f32_f32_f32p_test.cpp + test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp - test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp - test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp - test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp + test/tests/matmul_test.cpp ) endif() diff --git a/test/reference/binary_elementwise.cpp b/test/reference/binary_elementwise.cpp index f65136572c6689a919eea05359e3e539a4f9a2ed..803d87fb6c8373dff53c3c468acd8a071ab87042 100644 --- a/test/reference/binary_elementwise.cpp +++ b/test/reference/binary_elementwise.cpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -72,9 +72,8 @@ std::vector binary_elementwise_any_op_type( const auto height = std::max(lhs_height, rhs_height); const auto width = std::max(lhs_width, rhs_width); - std::vector dst; - dst.resize(height * width * size_in_bits / 8); KAI_ASSUME(width * size_in_bits % 8 == 0); + std::vector dst(height * width * size_in_bits / 8); for (size_t y = 0; y < height; ++y) { for (size_t x = 0; x < width; ++x) { diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp index 008ec320fe106e5016b7d5a71764d73ac90ef50c..a5fd1e66b53f8cbbf3cddd86c5fa3a39289b91bf 100644 --- a/test/reference/matmul.cpp +++ b/test/reference/matmul.cpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -189,24 +189,25 @@ template < typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> std::vector matmul_nt_t_quantized( - size_t m, size_t n, size_t k, // - const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, - size_t lhs_quant_width, // - const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, - size_t rhs_quant_width, // - const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width) { + size_t m, size_t n, size_t k, // + const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, // + size_t lhs_quant_height, size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, // + size_t rhs_quant_height, size_t rhs_quant_width, // + const void* bias_data, const void* bias_scales, const void* bias_zero_points, // + size_t bias_quant_width) { const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width); const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width); std::vector dst(m * n * sizeof(DstData)); - for (size_t y = 0; y < m; ++y) { - for (size_t x = 0; x < n; ++x) { + for (size_t row = 0; row < m; ++row) { + for (size_t col = 0; col < n; ++col) { DstData acc = 0; for (size_t i = 0; i < k; ++i) { - const auto lhs_data_index = y * k + i; - const auto lhs_quant_index = y / lhs_quant_height * lhs_num_quant_per_row + i / lhs_quant_width; + const auto lhs_data_index = row * k + i; + const auto lhs_quant_index = row / lhs_quant_height * lhs_num_quant_per_row + i / lhs_quant_width; const auto lhs_value = read_array(lhs_data, lhs_data_index); const auto lhs_scale = lhs_scales != nullptr ? read_array(lhs_scales, lhs_quant_index) : static_cast(1); @@ -214,8 +215,8 @@ std::vector matmul_nt_t_quantized( ? read_array(lhs_zero_points, lhs_quant_index) : static_cast(0); - const auto rhs_data_index = x * k + i; - const auto rhs_quant_index = x / rhs_quant_height * rhs_num_quant_per_row + i / rhs_quant_width; + const auto rhs_data_index = col * k + i; + const auto rhs_quant_index = col / rhs_quant_height * rhs_num_quant_per_row + i / rhs_quant_width; const auto rhs_value = read_array(rhs_data, rhs_data_index); const auto rhs_scale = rhs_scales != nullptr ? read_array(rhs_scales, rhs_quant_index) : static_cast(1); @@ -230,19 +231,19 @@ std::vector matmul_nt_t_quantized( } if (bias_data != nullptr) { - const auto bias_value = read_array(bias_data, x); + const auto bias_value = read_array(bias_data, col); const auto bias_scale = bias_scales != nullptr - ? read_array(bias_scales, x / bias_quant_width) + ? read_array(bias_scales, col / bias_quant_width) : static_cast(1); const auto bias_zero_point = bias_zero_points != nullptr - ? read_array(bias_zero_points, x / bias_quant_width) + ? read_array(bias_zero_points, col / bias_quant_width) : static_cast(0); acc += (static_cast(bias_value) - static_cast(bias_zero_point)) * static_cast(bias_scale); } - write_array(dst.data(), y * n + x, acc); + write_array(dst.data(), row * n + col, acc); } } diff --git a/test/reference/matmul_pack.cpp b/test/reference/matmul_pack.cpp index cd225b83efe13864edc5328742a431ea235758f1..55e916a0d45cf2b031fff0071def1f523ab3a0ed 100644 --- a/test/reference/matmul_pack.cpp +++ b/test/reference/matmul_pack.cpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -8,7 +8,6 @@ #include #include -#include #include #include "test/common/round.hpp" diff --git a/test/reference/quantize.cpp b/test/reference/quantize.cpp index 82a09fd7614db1b9f25a927414f7a49ccc177c56..67cc36045b191f437924b7d058d0f2c536023820 100644 --- a/test/reference/quantize.cpp +++ b/test/reference/quantize.cpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -266,10 +266,12 @@ std::vector quantize_asymmetric_per_block( template std::tuple, std::vector, std::vector> quantize_asymmetric_per_block_dynamic( const void* src, size_t height, size_t width, size_t quant_width) { + /* Calculate the asymmetric quantization information, one scaling per row */ auto [scales_src_type, zero_points] = compute_asymmetric_per_block_quantization_info( src, height, width, quant_width); + /* Do the actual quantization */ auto data = quantize_asymmetric_per_block( src, scales_src_type.data(), zero_points.data(), height, width, quant_width); diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp index 1acd19ed351828ef1bc522b793e2bb7ce723315b..2ffe5b7070d6ecfd116618eefceb8a8fe362ffa2 100644 --- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -540,9 +540,11 @@ INSTANTIATE_TEST_SUITE_P( // clang-format off MatMulShape{ 1, 1, 1}, MatMulShape{ 1, 49, 21}, + MatMulShape{ 16, 16, 4}, MatMulShape{ 20, 30, 40}, MatMulShape{ 23, 1, 43}, MatMulShape{ 32, 14, 1}, + MatMulShape{ 32, 32, 4}, MatMulShape{ 64, 64, 4}, MatMulShape{123, 85, 45}, MatMulShape{130, 130, 6}, @@ -572,6 +574,7 @@ INSTANTIATE_TEST_SUITE_P( MatMulShape{1, 16, 4}, MatMulShape{1, 16, 16}, MatMulShape{1, 17, 4}, + MatMulShape{1, 32, 4}, MatMulShape{1, 32, 32}, MatMulShape{1, 33, 200}, MatMulShape{1, 64, 4},