diff --git a/CMakeLists.txt b/CMakeLists.txt index 40ab01af72f934801ef18f2619c59efd96d9279f..80a16331af441fc87b2b7cb91d80edd59fb17566 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -56,22 +56,42 @@ set(KLEIDIAI_WARNING_FLAGS $<$:${KLEIDIAI_WARNING_FLAGS_CXX}> ) -add_library(kleidiai) - -target_sources(kleidiai PRIVATE +set(KLEIDIAI_FILES_NEON src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c +) + +set(KLEIDIAI_FILES_NEON_DOTPROD src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c +) + +set(KLEIDIAI_FILES_NEON_I8MM src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c - src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c) + src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c +) + +add_library(kleidiai + ${KLEIDIAI_FILES_NEON} + ${KLEIDIAI_FILES_NEON_DOTPROD} + ${KLEIDIAI_FILES_NEON_I8MM} +) target_include_directories(kleidiai - PRIVATE src/ + PUBLIC src + PRIVATE . ) +foreach(KLEIDIAI_SOURCE_FILE IN LISTS KLEIDIAI_FILES_NEON_DOTPROD) + set_property(SOURCE ${KLEIDIAI_SOURCE_FILE} PROPERTY COMPILE_OPTIONS -march=armv8.2-a+dotprod) +endforeach() + +foreach(KLEIDIAI_SOURCE_FILE IN LISTS KLEIDIAI_FILES_NEON_I8MM) + set_property(SOURCE ${KLEIDIAI_SOURCE_FILE} PROPERTY COMPILE_OPTIONS -march=armv8.2-a+i8mm) +endforeach() + target_compile_options(kleidiai PRIVATE ${KLEIDIAI_WARNING_FLAGS} ) diff --git a/src/kai_common.h b/src/kai_common.h index 9bb2550af70a53a66df3b28692d8d674b10b8eba..f7c205c37464de89531caa1ff873fe50307bb242 100644 --- a/src/kai_common.h +++ b/src/kai_common.h @@ -43,8 +43,6 @@ extern "C" { #define KAI_ASSUME_IF_MSG KAI_ASSERT_IF_MSG #define KAI_ASSUME_IF KAI_ASSERT_IF -#define KAI_UNUSED(x) (void)(x) - #define KAI_UNUSED(x) (void)(x) #define KAI_MIN(a, b) (((a) < (b)) ? (a) : (b)) #define KAI_MAX(a, b) (((a) > (b)) ? (a) : (b)) diff --git a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c index d93469a0e71ae8f8875fb0fdfe428f505c8990ce..6e618ccc08f130824ce0cbc31d51f97fc12487ac 100644 --- a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c +++ b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c @@ -71,12 +71,13 @@ void kai_run_lhs_quant_pack_qai8dxp_f32( float max0 = -FLT_MAX; float min0 = FLT_MAX; - float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); - float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); - // Find min/max for each channel int32_t k_idx = 0; + #if defined(__aarch64__) + float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); + float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); + for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + (size_t)k_idx); const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + (size_t)k_idx); @@ -135,7 +136,7 @@ void kai_run_lhs_quant_pack_qai8dxp_f32( for (; k_idx < (int32_t)k_internal; k_idx += k_block_len) { for (size_t k_block_idx = 0; k_block_idx < (size_t)k_block_len; ++k_block_idx) { // Clamp at the last valid k-index - const size_t k_idx_start = KAI_MIN((size_t)k_idx + k_block_idx, k); + const size_t k_idx_start = KAI_MIN((size_t)k_idx + k_block_idx, k - 1); const float src0_0 = *(src_ptr + k_idx_start); diff --git a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h index 85715dc661f90c5e454f4c9d39d0345b2eb14fa5..28bd1ac1b0129a36992ebfa1acbbf68b9867d34a 100644 --- a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h +++ b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h @@ -5,8 +5,8 @@ // #pragma once +#include #include -#include #ifdef __cplusplus extern "C" { diff --git a/src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c b/src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c index 321ae08c58d5cdd97ea95afaf49df739d886ed07..aed8a5f427b34f09a90805a0c8b4c62245a3da15 100644 --- a/src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c +++ b/src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c @@ -92,8 +92,10 @@ void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( const size_t k_idx_start0 = (x / 2) + kr_idx / 2 + s * (kr / sr) / 2; const size_t k_idx_start1 = k_idx_start0 + (kr / 2); - const size_t src_addr_byte0 = i * rhs_stride + k_idx_start0; - const size_t src_addr_byte1 = i * rhs_stride + k_idx_start1; + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = y + i >= n ? 0 : i; + const size_t src_addr_byte0 = src_row_idx * rhs_stride + k_idx_start0; + const size_t src_addr_byte1 = src_row_idx * rhs_stride + k_idx_start1; uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; @@ -135,7 +137,9 @@ void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( // Adjust the scales for (size_t i = 0; i < nr; ++i) { - *((float*)(dst_row)) = scale[y + i] * 0.0625F; + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(y + i, n - 1); + *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; dst_row += sizeof(float); } } diff --git a/src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h b/src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h index 8340c2bcc8f0cae464426a07e9a27a69d752e23f..1ce70c8bcc66c0f68f6c70f8674acd10c0c9fd00 100644 --- a/src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h +++ b/src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h @@ -5,8 +5,8 @@ // #pragma once +#include #include -#include #ifdef __cplusplus extern "C" { diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c index 6fee9b51c7c5071524b5a02f33defb555919f7b5..4d47ef1e947dd947d05aa31741dd67412d27ec3f 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c @@ -3,6 +3,9 @@ // // SPDX-License-Identifier: Apache-2.0 // +#if !defined(__ARM_FEATURE_DOTPROD) +#error "Dotprod extension required to compile this micro-kernel" +#else #include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" #include @@ -95,7 +98,6 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotpr void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { -#if defined(__ARM_FEATURE_DOTPROD) KAI_ASSERT(dst_stride_col == sizeof(float)); if (m == 0) { @@ -214,17 +216,5 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( } lhs_ptr_start += lhs_packed_stride; } -#else - KAI_ASSERT(false); - KAI_UNUSED(m); - KAI_UNUSED(n); - KAI_UNUSED(k); - KAI_UNUSED(lhs_packed); - KAI_UNUSED(rhs_packed); - KAI_UNUSED(dst); - KAI_UNUSED(dst_stride_row); - KAI_UNUSED(dst_stride_col); - KAI_UNUSED(scalar_min); - KAI_UNUSED(scalar_max); -#endif } +#endif // Architectural feature check diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h index 558f0f4c2b8674c594412167803fe834adfe7b3f..7334e72f57653ae063c7e3daea8121c878f9699e 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h @@ -9,7 +9,6 @@ #include #endif #include -#include #ifdef __cplusplus extern "C" { diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c index f243f3d7244524d4cd843fbf57d5e9631ce6f27c..6bff095a3d2babb95df7923f3fd1de076c67ca3d 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c @@ -3,6 +3,9 @@ // // SPDX-License-Identifier: Apache-2.0 // +#if !defined(__ARM_FEATURE_DOTPROD) +#error "Dotprod extension required to compile this micro-kernel" +#else #include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h" #include @@ -95,7 +98,6 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotpr void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod( size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { -#if defined(__ARM_FEATURE_DOTPROD) KAI_ASSERT(dst_stride_col == sizeof(float)); if (m == 0) { @@ -266,17 +268,5 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod( } lhs_ptr_start += lhs_packed_stride; } -#else - KAI_ASSERT(false); - KAI_UNUSED(m); - KAI_UNUSED(n); - KAI_UNUSED(k); - KAI_UNUSED(lhs_packed); - KAI_UNUSED(rhs_packed); - KAI_UNUSED(dst); - KAI_UNUSED(dst_stride_row); - KAI_UNUSED(dst_stride_col); - KAI_UNUSED(scalar_min); - KAI_UNUSED(scalar_max); -#endif } +#endif // Architectural feature check diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h index 69bf5c08f35b5adf531b83e5ac262910c83159cd..69a16a3b018fca789053e22a32f9ba0f46839902 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h @@ -9,7 +9,6 @@ #include #endif #include -#include #ifdef __cplusplus extern "C" { diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c index 0bdd4f9e4501f0dc5ddca884edd75f0bdf77e87c..45f431707eff835a24304da83a01e4422ff80609 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c @@ -3,6 +3,9 @@ // // SPDX-License-Identifier: Apache-2.0 // +#if !defined(__ARM_FEATURE_MATMUL_INT8) +#error "I8mm extension required to compile this micro-kernel" +#else #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h" #include @@ -95,7 +98,6 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm( void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { -#if defined(__ARM_FEATURE_MATMUL_INT8) KAI_ASSERT(dst_stride_col == sizeof(float)); if (m == 0) { @@ -263,17 +265,5 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm( : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); -#else - KAI_ASSERT(false); - KAI_UNUSED(m); - KAI_UNUSED(n); - KAI_UNUSED(k); - KAI_UNUSED(lhs_packed); - KAI_UNUSED(rhs_packed); - KAI_UNUSED(dst); - KAI_UNUSED(dst_stride_row); - KAI_UNUSED(dst_stride_col); - KAI_UNUSED(scalar_min); - KAI_UNUSED(scalar_max); -#endif } +#endif // Architectural feature check diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h index 782fa28ed3207c8318fa057b57cafc55fa969b96..8592d28516a4e9d4131df137c697a3c1a8154f5c 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h @@ -9,7 +9,6 @@ #include #endif #include -#include #ifdef __cplusplus extern "C" { diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c index 2c348c3741e74b7c1aa3f22593058ce3b4a1a053..293fad2907affa2d79e7e68e9a4c165bb05d00de 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c @@ -3,6 +3,9 @@ // // SPDX-License-Identifier: Apache-2.0 // +#if !defined(__ARM_FEATURE_MATMUL_INT8) +#error "I8mm extension required to compile this micro-kernel" +#else #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" #include @@ -95,7 +98,6 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { -#if defined(__ARM_FEATURE_MATMUL_INT8) KAI_ASSERT(dst_stride_col == sizeof(float)); if (m == 0) { @@ -481,17 +483,5 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); -#else - KAI_ASSERT(false); - KAI_UNUSED(m); - KAI_UNUSED(n); - KAI_UNUSED(k); - KAI_UNUSED(lhs_packed); - KAI_UNUSED(rhs_packed); - KAI_UNUSED(dst); - KAI_UNUSED(dst_stride_row); - KAI_UNUSED(dst_stride_col); - KAI_UNUSED(scalar_min); - KAI_UNUSED(scalar_max); -#endif } +#endif // Architectural feature check diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h index 1f350fe97136d4ac949b2761fc0355c0b83ae537..c22b4d7121d6b84408d09eeb2eb09ab604a8b361 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h @@ -9,7 +9,6 @@ #include #endif #include -#include #ifdef __cplusplus extern "C" { diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c index 2704c853ab3c39510cab5849c952d1e8ba266882..06adc07a8ae7d8d37ff64bf14915d2d713cf7020 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c @@ -3,6 +3,9 @@ // // SPDX-License-Identifier: Apache-2.0 // +#if !defined(__ARM_FEATURE_MATMUL_INT8) +#error "I8mm extension required to compile this micro-kernel" +#else #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h" #include @@ -95,7 +98,6 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm( void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm( size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { -#if defined(__ARM_FEATURE_MATMUL_INT8) KAI_ASSERT(dst_stride_col == sizeof(float)); if (m == 0) { @@ -353,17 +355,5 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm( : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); -#else - KAI_ASSERT(false); - KAI_UNUSED(m); - KAI_UNUSED(n); - KAI_UNUSED(k); - KAI_UNUSED(lhs_packed); - KAI_UNUSED(rhs_packed); - KAI_UNUSED(dst); - KAI_UNUSED(dst_stride_row); - KAI_UNUSED(dst_stride_col); - KAI_UNUSED(scalar_min); - KAI_UNUSED(scalar_max); -#endif } +#endif // Architectural feature check diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h index 6e6363c7afb25c1b4c60b753f65ad2388cfd87d0..94c5f6c19d5a765b00cdd68a677d47db6cea10cd 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h @@ -9,7 +9,6 @@ #include #endif #include -#include #ifdef __cplusplus extern "C" { diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c index 6a40b814348e4f472ffcbb8f8b9663529e53d4dc..9d8153a98b0608584d433ccc2d403898d4d33282 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c @@ -3,6 +3,9 @@ // // SPDX-License-Identifier: Apache-2.0 // +#if !defined(__ARM_FEATURE_MATMUL_INT8) +#error "I8mm extension required to compile this micro-kernel" +#else #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h" #include @@ -95,7 +98,6 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm( void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { -#if defined(__ARM_FEATURE_MATMUL_INT8) KAI_ASSERT(dst_stride_col == sizeof(float)); if (m == 0) { @@ -729,17 +731,5 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm( : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); -#else - KAI_ASSERT(false); - KAI_UNUSED(m); - KAI_UNUSED(n); - KAI_UNUSED(k); - KAI_UNUSED(lhs_packed); - KAI_UNUSED(rhs_packed); - KAI_UNUSED(dst); - KAI_UNUSED(dst_stride_row); - KAI_UNUSED(dst_stride_col); - KAI_UNUSED(scalar_min); - KAI_UNUSED(scalar_max); -#endif } +#endif // Architectural feature check diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h index 5d2e2d59c2f970a74074fe4ac1178587edb22f71..480f3c5f59b43b129beb43f3e7e4efd29359ec2d 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h @@ -9,7 +9,6 @@ #include #endif #include -#include #ifdef __cplusplus extern "C" {